diff --git a/src/main/scala/ABMatrixReg.scala b/src/main/scala/ABMatrixReg.scala index d6bf3b2..f8ffd74 100644 --- a/src/main/scala/ABMatrixReg.scala +++ b/src/main/scala/ABMatrixReg.scala @@ -14,7 +14,7 @@ import utility.sram.SRAMTemplate class ABMatrixRegIO(implicit p: Parameters) extends CuteBundle{ val FromDataController = new ABDataControlMatrixRegIO - val FromMemoryLoader = new ABMemoryLoaderMatrixRegIO + val FromMemoryLoader = new ABMemoryLoaderMatrixRegIO } class ABMatrixReg(scp_id: Int)(implicit p: Parameters) extends CuteModule{ @@ -22,7 +22,7 @@ class ABMatrixReg(scp_id: Int)(implicit p: Parameters) extends CuteModule{ val MatrixRegIO = new ABMatrixRegIO }) - + // 读写优先级逻辑:写优先于读 // 目前在Loader、DataController里面都加了FIFO,能保证一些堵的情况的发生 @@ -33,11 +33,12 @@ class ABMatrixReg(scp_id: Int)(implicit p: Parameters) extends CuteModule{ // MemoryLoader端的信号 val MemoryLoaderBankAddr = io.MatrixRegIO.FromMemoryLoader.BankAddr val MemoryLoaderData = io.MatrixRegIO.FromMemoryLoader.Data - + val MemoryLoaderByteMask = io.MatrixRegIO.FromMemoryLoader.ByteMask + // 写优先的MatrixReg控制逻辑 // write_go: 只要有写入请求(正常写或零填充)就为true val write_go = MemoryLoaderBankAddr.zip(MemoryLoaderData).map{case (a, b) => a.valid && b.valid}.reduce(_||_) - + // read_go: 只有在没有写请求时才允许读,实现写优先 val read_go = io.MatrixRegIO.FromDataController.BankAddr.valid && !MemoryLoaderBankAddr.map(_.valid).reduce(_||_) && @@ -53,13 +54,11 @@ class ABMatrixReg(scp_id: Int)(implicit p: Parameters) extends CuteModule{ // 实例化多个SRAM作为多个bank val sram_banks = (0 until ABMatrixRegNBanks) map { i => - // 使用SRAMTemplate替代SyncReadMem - // singlePort=true: 单端口SRAM,支持读写冲突处理 - // latency=1: 读延迟为1拍 + // Use byte-wide ways so SRAMTemplate waymask becomes byte write enable. val bank = Module(new SRAMTemplate( - gen = UInt((ABMatrixRegEntryByteSize*8).W), + gen = UInt(8.W), set = ABMatrixRegBankNEntries, - way = 1, + way = ABMatrixRegEntryByteSize, singlePort = true, latency = 1, hasMbist = false, @@ -89,7 +88,8 @@ class ABMatrixReg(scp_id: Int)(implicit p: Parameters) extends CuteModule{ // 写数据逻辑 val s0_bank_write_addr = MemoryLoaderBankAddr(i).bits val s0_bank_write_data = MemoryLoaderData(i).bits - val s0_bank_write_valid = MemoryLoaderBankAddr(i).valid && MemoryLoaderData(i).valid + val s0_bank_write_mask = MemoryLoaderByteMask(i).bits + val s0_bank_write_valid = MemoryLoaderBankAddr(i).valid && MemoryLoaderData(i).valid && MemoryLoaderByteMask(i).valid // 最终的写入控制 val s0_final_write_valid = write_go && s0_bank_write_valid @@ -99,8 +99,8 @@ class ABMatrixReg(scp_id: Int)(implicit p: Parameters) extends CuteModule{ when(write_go && s0_bank_write_valid){ if (YJPDebugEnable) { - printf("[ABMatrixReg_Write(%d)]Bank(%d): s0_bank_write_addr = %d, s0_bank_write_data = %x\n", - scp_id.U, i.U, s0_bank_write_addr, s0_bank_write_data) + printf("[ABMatrixReg_Write(%d)]Bank(%d): s0_bank_write_addr = %d, s0_bank_write_data = %x, mask = %x\n", + scp_id.U, i.U, s0_bank_write_addr, s0_bank_write_data, s0_bank_write_mask) } } @@ -110,15 +110,14 @@ class ABMatrixReg(scp_id: Int)(implicit p: Parameters) extends CuteModule{ bank.io.r.req.valid := bank_read_valid bank.io.r.req.bits.setIdx := s0_bank_read_addr // 读响应在下一拍返回(latency=1) - s1_bank_read_data := bank.io.r.resp.data(0) + s1_bank_read_data := bank.io.r.resp.data.asUInt // 连接SRAMTemplate的写接口 - // 使用apply方法设置写请求,way=1时waymask为None bank.io.w.req.valid := s0_final_write_valid bank.io.w.req.bits.setIdx := s0_final_write_addr - bank.io.w.req.bits.data(0) := s0_final_write_data + bank.io.w.req.bits.waymask.get := s0_bank_write_mask + bank.io.w.req.bits.data := s0_final_write_data.asTypeOf(Vec(ABMatrixRegEntryByteSize, UInt(8.W))) bank } } - diff --git a/src/main/scala/AMemoryLoader.scala b/src/main/scala/AMemoryLoader.scala index 02cddaa..dcd4539 100644 --- a/src/main/scala/AMemoryLoader.scala +++ b/src/main/scala/AMemoryLoader.scala @@ -12,6 +12,254 @@ import org.chipsalliance.cde.config._ class ASourceIdSearch(implicit p: Parameters) extends CuteBundle{ val MatrixRegBankId = UInt(log2Ceil(ABMatrixRegNBanks).W) val MatrixRegAddr = UInt(log2Ceil(ABMatrixRegBankNEntries).W) + val MatrixRegisTail = Bool() + val BeatIndex = UInt(log2Ceil(ABMatrixRegEntryByteSize).W) +} + +class TransDataPacket(implicit p: Parameters) extends CuteBundle { + val data = UInt(8.W) + val mask = Bool() + val entry_offset = UInt(log2Ceil(ABMatrixRegEntryByteSize).W) +} + +class TransAlignPipe(bankId: Int)(implicit p: Parameters) extends CuteModule { + private val transLoadSize = Trans_Load_Size + private val transLoadSizeBits = log2Ceil(transLoadSize) + private val entryOffsetBits = log2Ceil(ABMatrixRegEntryByteSize) + private val routerLatency = 3 + private val drainCycles = transLoadSize + 2 + routerLatency // 两级 barrel shifter 和 OOORouter 三级流水都需要排空。 + + val io = IO(new Bundle { + val in_data = Input(UInt(outsideDataWidth.W)) + val in_mask = Input(UInt(outsideDataWidthByte.W)) + val resp_beat_cnt = Input(UInt(log2Ceil(ABMatrixRegEntryByteSize + 1).W)) + val entry_offset = Input(UInt(entryOffsetBits.W)) + val debug_time = Input(UInt(64.W)) + val is_drain_trigger = Input(Bool()) + val in_valid = Input(Bool()) + val bus_stall = Output(Bool()) + val empty = Output(Bool()) + val out = Decoupled(Vec(transLoadSize, new TransDataPacket)) + }) + + require(isPow2(transLoadSize), "Trans_Load_Size must be power of 2") + require(transLoadSize == 8 || transLoadSize == 16, "Trans_Load_Size currently supports 8 or 16") + + val s_normal :: s_drain :: Nil = Enum(2) + val state = RegInit(s_normal) + val drain_cnt = RegInit(0.U(log2Ceil(drainCycles + 1).W)) + + val out_valid_now = Wire(Bool()) + + val inBytes = io.in_data.asTypeOf(Vec(outsideDataWidthByte, UInt(8.W))) + val inMaskBits = io.in_mask.asBools + val bankData = VecInit((0 until transLoadSize).map { i => + val byteIdx = bankId + i * ABMatrixRegNBanks + Mux(inMaskBits(byteIdx), inBytes(byteIdx), 0.U(8.W)) + }) + // Tail/full mask is consumed here only to zero invalid payload bytes. The + // downstream write mask is regenerated from BeatIndex. + + def rotateLeft[T <: Data](data: Vec[T], amount: UInt): Vec[T] = { + VecInit((0 until transLoadSize).map { i => + val idx = (i.U(transLoadSizeBits.W) + amount.pad(transLoadSizeBits))(transLoadSizeBits - 1, 0) + data(idx) + }) + } + + val shift_amt = io.resp_beat_cnt(transLoadSizeBits - 1, 0) + val shift_low = shift_amt(1, 0) + val shift_high_amt = if (transLoadSizeBits > 2) Cat(shift_amt(transLoadSizeBits - 1, 2), 0.U(2.W)) else 0.U(transLoadSizeBits.W) + + val coarseData = rotateLeft(bankData, shift_high_amt) + + val stage1Data = RegInit(VecInit(Seq.fill(transLoadSize)(0.U(8.W)))) + val stage1Valid = RegInit(false.B) + val stage1ShiftLow = RegInit(0.U(2.W)) + val stage1EntryOffset = RegInit(0.U(entryOffsetBits.W)) + + val stage2Data = RegInit(VecInit(Seq.fill(transLoadSize)(0.U(8.W)))) + val stage2Valid = RegInit(false.B) + val stage2EntryOffset = RegInit(0.U(entryOffsetBits.W)) + + val fineData = rotateLeft(stage1Data, stage1ShiftLow) + + val laneData = Seq.tabulate(transLoadSize) { i => + RegInit(VecInit(Seq.fill(i + 1)(0.U(8.W)))) + } + // laneValid only marks that a response packet occupies this triangular + // pipeline slot. Tail/full masks have already zeroed invalid payload bytes; + // SRAM write masks are regenerated later from BeatIndex. + val laneValid = Seq.tabulate(transLoadSize) { i => + RegInit(VecInit(Seq.fill(i + 1)(false.B))) + } + // offset 与最后一级 lane 的物理位置绑定,而不是一拍内所有 lane 共享。 + // 新输入的 offset 进入 offset(0),随后沿 offset(0)->offset(1)->... 移动, + // 对齐“同一笔输入数据依次出现在最后一级第 0、1、...、N-1 个位置”的时空轨迹。 + val entryOffsetPipe = RegInit(VecInit(Seq.fill(transLoadSize)(0.U(entryOffsetBits.W)))) + + val laneValidBits = laneValid.zipWithIndex.flatMap { case (validVec, laneIdx) => + (0 to laneIdx).map(validVec(_)) + } + val any_valid_in_lane_pipe = laneValidBits.reduce(_ || _) + val any_valid_in_pipe = stage1Valid || stage2Valid || any_valid_in_lane_pipe + val input_accept = state === s_normal && io.in_valid + val pipe_advance = input_accept || (state === s_drain) + + if (bankId == 0) { + when(input_accept) { + printf("[AML_TRANS_ALIGN_IN<%d>] beatCnt:%d shift:%d shiftHigh:%d shiftLow:%d entry:%d raw:%x mask:%x bankData:%x coarse:%x drainTrig:%d\n", + io.debug_time, io.resp_beat_cnt, shift_amt, shift_high_amt, shift_low, + io.entry_offset, io.in_data, io.in_mask, bankData.asUInt, + coarseData.asUInt, io.is_drain_trigger.asUInt) + } + } + + // MatrixReg 写侧恒 ready,因此转置对齐流水线不再等待末端 valid/ready。 + // drain 期间仍反压总线响应,防止下一组数据在旧组 router 三级流水未排空前进入。 + io.bus_stall := state === s_drain + io.empty := state === s_normal && !any_valid_in_pipe + + when(pipe_advance) { + stage1Data := coarseData + stage1Valid := input_accept + stage1ShiftLow := shift_low + stage1EntryOffset := io.entry_offset + + stage2Data := fineData + stage2Valid := stage1Valid + stage2EntryOffset := stage1EntryOffset + + for (i <- 0 until transLoadSize) { + laneData(i)(0) := stage2Data(i) + laneValid(i)(0) := stage2Valid + for (j <- 1 to i) { + laneData(i)(j) := laneData(i)(j - 1) + laneValid(i)(j) := laneValid(i)(j - 1) + } + } + + entryOffsetPipe(0) := Mux(stage2Valid, stage2EntryOffset, 0.U) + for (i <- 1 until transLoadSize) { + entryOffsetPipe(i) := entryOffsetPipe(i - 1) + } + } + + val outPackets = Wire(Vec(transLoadSize, new TransDataPacket)) + val outValids = Wire(Vec(transLoadSize, Bool())) + for (i <- 0 until transLoadSize) { + val depthIdx = i + outPackets(i).data := laneData(i)(depthIdx) + outPackets(i).mask := laneValid(i)(depthIdx) + outPackets(i).entry_offset := entryOffsetPipe(i) + outValids(i) := laneValid(i)(depthIdx) + } + out_valid_now := outValids.asUInt.orR + + io.out.bits := outPackets + io.out.valid := out_valid_now && pipe_advance + + if (bankId == 0) { + val outDataUInt = VecInit((0 until transLoadSize).map(i => outPackets(i).data)).asUInt + val outEntryUInt = VecInit((0 until transLoadSize).map(i => outPackets(i).entry_offset)).asUInt + when(io.out.valid) { + printf("[AML_TRANS_ALIGN_OUT<%d>] valids:%b data:%x entry:%x\n", + io.debug_time, outValids.asUInt, outDataUInt, outEntryUInt) + } + } + + switch(state) { + is(s_normal) { + when(input_accept && io.is_drain_trigger) { + state := s_drain + drain_cnt := (drainCycles - 1).U + } + } + is(s_drain) { + when(drain_cnt === 0.U) { + state := s_normal + }.otherwise { + drain_cnt := drain_cnt - 1.U + } + } + } +} + +class OOORouter(implicit p: Parameters) extends CuteModule { + private val transLoadSize = Trans_Load_Size + private val entryOffsetBits = log2Ceil(ABMatrixRegEntryByteSize) + private val groupSize = math.max(1, transLoadSize / 4) + + val io = IO(new Bundle { + val in = Flipped(Decoupled(Vec(transLoadSize, new TransDataPacket))) + val final_data = Output(UInt(ABMatrixRegEntryBitSize.W)) + val final_mask = Output(UInt(ABMatrixRegEntryByteSize.W)) + val valid = Output(Bool()) + val empty = Output(Bool()) + }) + + require(ABMatrixRegEntryByteSize == 32, "OOORouter currently targets 32B AB MatrixReg entries") + require(transLoadSize == 8 || transLoadSize == 16, "OOORouter currently supports Trans_Load_Size 8 or 16") + + io.in.ready := true.B + + val inputFire = io.in.valid && io.in.ready + + // Byte mask 与 data 使用同一套 offset 译码和规约树。 + // 不保留历史 life mask,避免在当前拍没有有效 data 的 byte 上误拉高写使能。 + val stage1Data = RegInit(VecInit(Seq.fill(transLoadSize)(0.U(32.W)))) + val stage1Mask = RegInit(VecInit(Seq.fill(transLoadSize)(0.U(4.W)))) + val stage1WordIdx = RegInit(VecInit(Seq.fill(transLoadSize)(0.U(3.W)))) + val stage1Valid = RegInit(VecInit(Seq.fill(transLoadSize)(false.B))) + val stage1AnyValid = RegInit(false.B) + + val stage2Data = RegInit(VecInit(Seq.fill(4)(0.U(ABMatrixRegEntryBitSize.W)))) + val stage2Mask = RegInit(VecInit(Seq.fill(4)(0.U(ABMatrixRegEntryByteSize.W)))) + val stage2Valid = RegInit(false.B) + + val stage3Data = RegInit(0.U(ABMatrixRegEntryBitSize.W)) + val stage3Mask = RegInit(0.U(ABMatrixRegEntryByteSize.W)) + val stage3Valid = RegInit(false.B) + + for (i <- 0 until transLoadSize) { + val pkt = io.in.bits(i) + val byteInWord = pkt.entry_offset(1, 0) + val wordIdx = pkt.entry_offset(entryOffsetBits - 1, 2) + stage1Data(i) := Mux(inputFire && pkt.mask, (pkt.data.pad(32) << (byteInWord << 3))(31, 0), 0.U) + stage1Mask(i) := Mux(inputFire && pkt.mask, UIntToOH(byteInWord, 4).asUInt, 0.U) + stage1WordIdx(i) := wordIdx + stage1Valid(i) := inputFire && pkt.mask + } + val inputAnyValid = io.in.bits.map(_.mask).reduce(_ || _) + stage1AnyValid := inputFire && inputAnyValid + + val expandedData = (0 until transLoadSize).map { i => + Mux(stage1Valid(i), (stage1Data(i).asUInt << (stage1WordIdx(i) << 5))(ABMatrixRegEntryBitSize - 1, 0), 0.U(ABMatrixRegEntryBitSize.W)) + } + val expandedMask = (0 until transLoadSize).map { i => + Mux(stage1Valid(i), (stage1Mask(i).asUInt << (stage1WordIdx(i) << 2))(ABMatrixRegEntryByteSize - 1, 0), 0.U(ABMatrixRegEntryByteSize.W)) + } + + for (g <- 0 until 4) { + val start = g * groupSize + val end = math.min(start + groupSize, transLoadSize) + val dataTerms = (start until end).map(expandedData) + val maskTerms = (start until end).map(expandedMask) + val reducedData = if (dataTerms.nonEmpty) dataTerms.reduce(_ | _) else 0.U(ABMatrixRegEntryBitSize.W) + val reducedMask = if (maskTerms.nonEmpty) maskTerms.reduce(_ | _) else 0.U(ABMatrixRegEntryByteSize.W) + stage2Data(g) := reducedData + stage2Mask(g) := reducedMask + } + stage2Valid := stage1AnyValid + + stage3Data := stage2Data.reduce(_ | _) + stage3Mask := stage2Mask.reduce(_ | _) + stage3Valid := stage2Valid + + io.final_data := stage3Data + io.final_mask := stage3Mask + io.valid := stage3Valid + io.empty := !stage1AnyValid && !stage2Valid && !stage3Valid } class AMemoryLoader(implicit p: Parameters) extends CuteModule{ @@ -28,8 +276,11 @@ class AMemoryLoader(implicit p: Parameters) extends CuteModule{ io.ToMatrixRegIO.BankAddr.map(_.bits := DontCare) io.ToMatrixRegIO.Data.map(_.valid := false.B) io.ToMatrixRegIO.Data.map(_.bits := DontCare) + io.ToMatrixRegIO.ByteMask.map(_.valid := false.B) + io.ToMatrixRegIO.ByteMask.map(_.bits := Fill(ABMatrixRegEntryByteSize, true.B)) io.LocalMMUIO.Request.valid := false.B io.LocalMMUIO.Request.bits := DontCare // It will be set if Request is valid + io.LocalMMUIO.Request.bits.RequestMask := Fill(MMUMaskWidth, 1.U(1.W)) io.LocalMMUIO.Response.ready := false.B io.ConfigInfo.MicroTaskEndValid := false.B io.ConfigInfo.MicroTaskReady := false.B @@ -52,7 +303,6 @@ class AMemoryLoader(implicit p: Parameters) extends CuteModule{ difftestAmuFinish.valid := (io.ToMatrixRegIO.BankAddr.map(_.valid).reduce(_||_) || (io.ConfigInfo.MicroTaskEndValid && io.ConfigInfo.MicroTaskEndReady)) difftestAmuFinish.pc := pcReg - // DiffAmuFinishEvent packing is parameterized by words-per-bank. val eventWordsPerBank = difftestAmuFinish.data.length / ABMatrixRegNBanks val abMRegWordsPerBank = ABMatrixRegEntryBitSize / 64 require(difftestAmuFinish.data.length % ABMatrixRegNBanks == 0, "DiffAmuFinishEvent.data should divide by AB bank count") @@ -61,6 +311,7 @@ class AMemoryLoader(implicit p: Parameters) extends CuteModule{ for (i <- 0 until ABMatrixRegNBanks) { difftestAmuFinish.bankValid(i) := io.ToMatrixRegIO.BankAddr(i).valid difftestAmuFinish.bankAddr(i) := io.ToMatrixRegIO.BankAddr(i).bits + difftestAmuFinish.bankMask(i) := io.ToMatrixRegIO.ByteMask(i).bits for (w <- 0 until eventWordsPerBank) { if (w < abMRegWordsPerBank) { val lo = w * 64 @@ -77,24 +328,34 @@ class AMemoryLoader(implicit p: Parameters) extends CuteModule{ val ConfigInfo = io.ConfigInfo val Tensor_Block_BaseAddr = Reg(UInt(MMUAddrWidth.W)) val ApplicationTensor_A_Stride_M = RegInit(0.U(MMUAddrWidth.W)) - val dataType = RegInit(0.U(ElementDataType.DataTypeBitWidth.W)) + val HasTail = RegInit(false.B) + val TailByteMask = RegInit(0.U(log2Ceil(outsideDataWidthByte + 1).W)) + val K_Beat_Count = RegInit(0.U(MatrixRegMaxTensorDimBitSize.W)) val MatrixRegTensor_M = RegInit(0.U(MatrixRegMaxTensorDimBitSize.W)) - val MatrixRegTensor_K = RegInit(0.U(MatrixRegMaxTensorDimBitSize.W)) val Conherent = RegInit(true.B) + val Is_Transpose = RegInit(false.B) val Is_ZeroLoad = RegInit(false.B) val Is_FullLoad = RegInit(false.B) val s_idle :: s_mm_task :: Nil = Enum(2) val state = RegInit(s_idle) - val s_load_idle :: s_load_init :: s_load_working :: s_load_end :: Nil = Enum(4) + val s_load_idle :: s_load_init :: s_load_working :: s_load_quiesce :: s_load_end :: Nil = Enum(5) val memoryload_state = RegInit(s_load_idle) + private val transposeEndDrainCycles = Trans_Load_Size + 2 + 3 + val transposeEndDrainCnt = RegInit(0.U(log2Ceil(transposeEndDrainCycles + 1).W)) - val TotalLoadSize = RegInit(0.U((log2Ceil(Tensor_MN*ReduceGroupSize)+1).W)) + val TotalLoadSize = RegInit(0.U((log2Ceil(Tensor_MN*ReduceGroupSize*outsideDataWidthByte)+1).W)) val TotalRequestSize = RegInit(0.U((log2Ceil(Tensor_MN*ReduceGroupSize*ReduceWidthByte)).W)) val CurrentLoaded_BlockTensor_M_Iter = RegInit(0.U(MatrixRegMaxTensorDimBitSize.W)) val CurrentLoaded_BlockTensor_K_Iter = RegInit(0.U(MatrixRegMaxTensorDimBitSize.W)) - val Request_M_Iter_Time = RegInit(0.U(log2Ceil(Matrix_MN).W)) + val Request_M_Iter_Time = RegInit(0.U(log2Ceil(math.max(Matrix_MN, ABMatrixRegEntryByteSize)).W)) + + val group_req_cnt = RegInit(0.U(log2Ceil(ABMatrixRegEntryByteSize + 1).W)) + val group_resp_cnt = RegInit(0.U(log2Ceil(ABMatrixRegEntryByteSize + 1).W)) + val group_size_reg = RegInit(0.U(log2Ceil(ABMatrixRegEntryByteSize + 1).W)) + + private val transBaseAddrBits = log2Ceil(ABMatrixRegBankNEntries) val SoureceIdSearchTable = RegInit(VecInit(Seq.fill(SoureceMaxNum)(0.U((new ASourceIdSearch).getWidth.W)))) val MaxRequestIter = RegInit(0.U((log2Ceil(Tensor_MN*ReduceGroupSize*ReduceWidthByte)).W)) @@ -102,6 +363,7 @@ class AMemoryLoader(implicit p: Parameters) extends CuteModule{ val MReg_Fill_Table = RegInit((VecInit(Seq.fill(AMemoryLoaderReadFromMemoryFIFODepth)(0.U(outsideDataWidth.W))))) val MReg_Fill_Table_MReg_Addr = RegInit((VecInit(Seq.fill(AMemoryLoaderReadFromMemoryFIFODepth)(0.U(log2Ceil(ABMatrixRegBankNEntries).W))))) val MReg_Fill_Table_Time = RegInit((VecInit(Seq.fill(AMemoryLoaderReadFromMemoryFIFODepth)(0.U((log2Ceil(outsideDataWidthByte/ABMatrixRegEntryByteSize)+1).W))))) + val MReg_Fill_Table_IsTail = RegInit(VecInit(Seq.fill(AMemoryLoaderReadFromMemoryFIFODepth)(false.B))) val MReg_Fill_Table_Free = MReg_Fill_Table_Time.map(_ === 0.U) val MReg_Fill_Table_Insert_Index = PriorityEncoder(MReg_Fill_Table_Free) val MReg_Fill_Table_Not_Full = MReg_Fill_Table_Free.reduce(_ || _) @@ -112,15 +374,43 @@ class AMemoryLoader(implicit p: Parameters) extends CuteModule{ val Bank_Fill_Search_FIFO_Tail = RegInit((VecInit(Seq.fill(ABMatrixRegNBanks)(0.U(log2Ceil(AMemoryLoaderReadFromMemoryFIFODepth).W))))) val Bank_Fill_Search_FIFO_Full = WireInit(VecInit(Seq.fill(ABMatrixRegNBanks)(false.B))) val Bank_Fill_Search_FIFO_Empty = WireInit(VecInit(Seq.fill(ABMatrixRegNBanks)(true.B))) - val Bank_Fill_Valid = WireInit(VecInit(Seq.fill(ABMatrixRegNBanks)(false.B))) - val Have_Bank_Fill = Bank_Fill_Valid.reduce(_ || _) for(i <- 0 until ABMatrixRegNBanks){ Bank_Fill_Search_FIFO_Full(i) := Bank_Fill_Search_FIFO_Tail(i) === WrapInc(Bank_Fill_Search_FIFO_Head(i), AMemoryLoaderReadFromMemoryFIFODepth) Bank_Fill_Search_FIFO_Empty(i) := Bank_Fill_Search_FIFO_Head(i) === Bank_Fill_Search_FIFO_Tail(i) - Bank_Fill_Valid(i) := Bank_Fill_Search_FIFO_Head(i) =/= Bank_Fill_Search_FIFO_Tail(i) } + val transAlignPipes = Seq.tabulate(ABMatrixRegNBanks)(i => Module(new TransAlignPipe(i))) + val transRouters = Seq.tabulate(ABMatrixRegNBanks)(_ => Module(new OOORouter)) + val transPipeInValid = WireInit(false.B) + val transPipeInData = WireInit(0.U(outsideDataWidth.W)) + val transPipeInMask = WireInit(0.U(outsideDataWidthByte.W)) + val transPipeRespBeatCnt = WireInit(0.U(group_resp_cnt.getWidth.W)) + val transPipeEntryOffset = WireInit(0.U(log2Ceil(ABMatrixRegEntryByteSize).W)) + val transPipeDrainTrigger = WireInit(false.B) + val transBusStall = transAlignPipes.map(_.io.bus_stall).reduce(_ || _) + val transWriteBaseAddr = RegInit(0.U(transBaseAddrBits.W)) + val transWriteAddrCnt = RegInit(0.U(log2Ceil(Trans_Load_Size).W)) + + for (i <- 0 until ABMatrixRegNBanks) { + transAlignPipes(i).io.in_data := transPipeInData + transAlignPipes(i).io.in_mask := transPipeInMask + transAlignPipes(i).io.resp_beat_cnt := transPipeRespBeatCnt + transAlignPipes(i).io.entry_offset := transPipeEntryOffset + transAlignPipes(i).io.debug_time := io.DebugInfo.DebugTimeStampe + transAlignPipes(i).io.is_drain_trigger := transPipeDrainTrigger + transAlignPipes(i).io.in_valid := transPipeInValid + transRouters(i).io.in <> transAlignPipes(i).io.out + } + val transAlignEmpty = transAlignPipes.map(_.io.empty).reduce(_ && _) + val transRouterEmpty = transRouters.map(_.io.empty).reduce(_ && _) + val transPipelineEmpty = transAlignEmpty && transRouterEmpty + val transRouterValidVec = VecInit(transRouters.map(_.io.valid)).asUInt + val transRouterWriteValid = transRouters.map(_.io.valid).reduce(_ || _) + val transWriteAddrOffset = transWriteAddrCnt * ReduceGroupSize.U + val transWriteAddrWide = transWriteBaseAddr + transWriteAddrOffset + val transWriteAddr = transWriteAddrWide(transBaseAddrBits - 1, 0) + val Request = io.LocalMMUIO.Request when(state === s_idle){ @@ -129,19 +419,22 @@ class AMemoryLoader(implicit p: Parameters) extends CuteModule{ state := s_mm_task memoryload_state := s_load_init MatrixRegTensor_M := ConfigInfo.MatrixRegTensor_M - MatrixRegTensor_K := ConfigInfo.MatrixRegTensor_K CurrentMatrixRegId := ConfigInfo.MatrixRegId Tensor_Block_BaseAddr := ConfigInfo.ApplicationTensor_A.ApplicationTensor_A_BaseVaddr ApplicationTensor_A_Stride_M := ConfigInfo.ApplicationTensor_A.ApplicationTensor_A_Stride_M - dataType := ConfigInfo.ApplicationTensor_A.dataType + HasTail := ConfigInfo.ApplicationTensor_A.HasTail + TailByteMask := ConfigInfo.ApplicationTensor_A.TailByteMask + K_Beat_Count := ConfigInfo.ApplicationTensor_A.K_Beat_Count Is_ZeroLoad := ConfigInfo.LoadTaskInfo.Is_ZeroLoad Is_FullLoad := ConfigInfo.LoadTaskInfo.Is_FullLoad Conherent := ConfigInfo.Conherent + Is_Transpose := ConfigInfo.Is_Transpose if(YJPAMLDebugEnable){ - printf("[AML<%d>]AMemoryLoader Task Start, MatrixRegTensor_M:%d, MatrixRegTensor_K:%d, BaseVaddr:%x, Stride_M:%x, dataType:%d, Is_ZeroLoad:%d, Is_FullLoad:%d\n", + printf("[AML<%d>]AMemoryLoader Task Start, MatrixRegTensor_M:%d, MatrixRegTensor_K:%d, BaseVaddr:%x, Stride_M:%x, dataType:%d, Is_Transpose:%d, Is_ZeroLoad:%d, Is_FullLoad:%d\n", io.DebugInfo.DebugTimeStampe, ConfigInfo.MatrixRegTensor_M, ConfigInfo.MatrixRegTensor_K, ConfigInfo.ApplicationTensor_A.ApplicationTensor_A_BaseVaddr, ConfigInfo.ApplicationTensor_A.ApplicationTensor_A_Stride_M, - ConfigInfo.ApplicationTensor_A.dataType, ConfigInfo.LoadTaskInfo.Is_ZeroLoad.asUInt, ConfigInfo.LoadTaskInfo.Is_FullLoad.asUInt) + ConfigInfo.ApplicationTensor_A.dataType, ConfigInfo.Is_Transpose.asUInt, + ConfigInfo.LoadTaskInfo.Is_ZeroLoad.asUInt, ConfigInfo.LoadTaskInfo.Is_FullLoad.asUInt) } } } @@ -154,13 +447,20 @@ class AMemoryLoader(implicit p: Parameters) extends CuteModule{ CurrentLoaded_BlockTensor_M_Iter := 0.U CurrentLoaded_BlockTensor_K_Iter := 0.U Request_M_Iter_Time := 0.U - MaxRequestIter := MatrixRegTensor_M * MatrixRegTensor_K * ReduceWidthByte.U / outsideDataWidthByte.U + group_req_cnt := 0.U + group_resp_cnt := 0.U + group_size_reg := 0.U + transposeEndDrainCnt := 0.U + transWriteBaseAddr := 0.U + transWriteAddrCnt := 0.U + MaxRequestIter := MatrixRegTensor_M * K_Beat_Count Bank_Fill_Search_FIFO := 0.U.asTypeOf(Bank_Fill_Search_FIFO) Bank_Fill_Search_FIFO_Head := 0.U.asTypeOf(Bank_Fill_Search_FIFO_Head) Bank_Fill_Search_FIFO_Tail := 0.U.asTypeOf(Bank_Fill_Search_FIFO_Tail) MReg_Fill_Table := 0.U.asTypeOf(MReg_Fill_Table) MReg_Fill_Table_MReg_Addr := 0.U.asTypeOf(MReg_Fill_Table_MReg_Addr) MReg_Fill_Table_Time := 0.U.asTypeOf(MReg_Fill_Table_Time) + MReg_Fill_Table_IsTail := VecInit(Seq.fill(AMemoryLoaderReadFromMemoryFIFODepth)(false.B)) } is(s_load_working) { io.ToMatrixRegIO.active := true.B @@ -174,6 +474,8 @@ class AMemoryLoader(implicit p: Parameters) extends CuteModule{ io.ToMatrixRegIO.BankAddr(i).valid := true.B io.ToMatrixRegIO.Data(i).bits := 0.U io.ToMatrixRegIO.Data(i).valid := true.B + io.ToMatrixRegIO.ByteMask(i).bits := Fill(ABMatrixRegEntryByteSize, true.B) + io.ToMatrixRegIO.ByteMask(i).valid := true.B } TotalLoadSize := TotalLoadSize + 1.U if (YJPAMLDebugEnable) printf("[AML<%d>]ZeroLoad, TotalLoadSize: %d\n", io.DebugInfo.DebugTimeStampe, TotalLoadSize) @@ -184,30 +486,101 @@ class AMemoryLoader(implicit p: Parameters) extends CuteModule{ } when(Is_FullLoad){ - // 矩阵访存顺序:按 M 分 bank 交织,再扫 K。地址 = BaseAddr + M*Stride_M + K*ReduceWidthByte - val RequestMatrixRegBankId = (CurrentLoaded_BlockTensor_M_Iter + Request_M_Iter_Time) % ABMatrixRegNBanks.U - val RequestMatrixRegAddr = (CurrentLoaded_BlockTensor_M_Iter + Request_M_Iter_Time) / ABMatrixRegNBanks.U * ReduceGroupSize.U + CurrentLoaded_BlockTensor_K_Iter + //先转换成独热码然后进行减一即可计算出掩码 + val tailTaskMask = UIntToOH(TailByteMask, outsideDataWidthByte + 1).asUInt - 1.U(outsideDataWidthByte.W) + val fullTaskMask = Fill(outsideDataWidthByte, true.B) + val RequestBeatIsTail = HasTail && (CurrentLoaded_BlockTensor_K_Iter === (K_Beat_Count - 1.U)) - Request.bits.RequestVirtualAddr := Tensor_Block_BaseAddr + (CurrentLoaded_BlockTensor_M_Iter + Request_M_Iter_Time) * ApplicationTensor_A_Stride_M + CurrentLoaded_BlockTensor_K_Iter * ReduceWidthByte.U + // Transpose request path reuses the normal iter registers: + // CurrentLoaded_BlockTensor_M_Iter -> large_M group index + // CurrentLoaded_BlockTensor_K_Iter -> K/M-beat index + // Request_M_Iter_Time -> small_M inside current group + val transpose_large_m_base = CurrentLoaded_BlockTensor_M_Iter * ABMatrixRegEntryByteSize.U + val transpose_current_m = transpose_large_m_base + Request_M_Iter_Time + val transpose_group_in_range = transpose_large_m_base < MatrixRegTensor_M + val transpose_group_remain = Mux(transpose_group_in_range, MatrixRegTensor_M - transpose_large_m_base, 0.U) + val current_group_size = Wire(UInt(log2Ceil(ABMatrixRegEntryByteSize + 1).W)) + current_group_size := Mux( + transpose_group_remain < ABMatrixRegEntryByteSize.U, + transpose_group_remain(current_group_size.getWidth - 1, 0), + ABMatrixRegEntryByteSize.U(current_group_size.getWidth.W) + ) + val group_has_no_requests = group_req_cnt === 0.U && group_resp_cnt === 0.U + val group_is_idle = group_has_no_requests && transPipelineEmpty + val active_group_size = Mux(group_has_no_requests, current_group_size, group_size_reg) + val transpose_group_can_issue = Mux( + group_has_no_requests, + group_is_idle && (current_group_size =/= 0.U), + group_req_cnt < active_group_size + ) + val transpose_req_enable = (TotalRequestSize < MaxRequestIter) && transpose_group_can_issue + + // 矩阵访存顺序:按 M 分 bank 交织,再扫 K。地址 = BaseAddr + M*Stride_M + K*64B + val RequestMatrixRegMIndex = Mux(Is_Transpose, transpose_current_m, CurrentLoaded_BlockTensor_M_Iter + Request_M_Iter_Time) + val NormalRequestMatrixRegBankId = RequestMatrixRegMIndex % ABMatrixRegNBanks.U + val NormalRequestMatrixRegBaseAddr = ((RequestMatrixRegMIndex / ABMatrixRegNBanks.U) * ReduceGroupSize.U) + val NormalRequestMatrixRegAddr = NormalRequestMatrixRegBaseAddr + (CurrentLoaded_BlockTensor_K_Iter << log2Ceil(MAX_Fill_Times)) + val TransposeRequestMatrixRegAddr = CurrentLoaded_BlockTensor_K_Iter * (Trans_Load_Size * ReduceGroupSize).U + CurrentLoaded_BlockTensor_M_Iter + val RequestMatrixRegBankId = Mux(Is_Transpose, 0.U, NormalRequestMatrixRegBankId) + val RequestMatrixRegAddr = Mux(Is_Transpose, TransposeRequestMatrixRegAddr, NormalRequestMatrixRegAddr) + + Request.bits.RequestVirtualAddr := Tensor_Block_BaseAddr + RequestMatrixRegMIndex * ApplicationTensor_A_Stride_M + (CurrentLoaded_BlockTensor_K_Iter << log2Ceil(outsideDataWidthByte)) val sourceId = Mux(Conherent, io.LocalMMUIO.ConherentRequsetSourceID, io.LocalMMUIO.nonConherentRequsetSourceID) Request.bits.RequestConherent := Conherent Request.bits.RequestSourceID := sourceId.bits Request.bits.RequestType_isWrite := false.B - Request.valid := (TotalRequestSize < MaxRequestIter) + Request.bits.RequestMask := Fill(MMUMaskWidth, 1.U(1.W)) + Request.valid := Mux(Is_Transpose, transpose_req_enable, TotalRequestSize < MaxRequestIter) when(Request.fire){ val TableItem = Wire(new ASourceIdSearch) TableItem.MatrixRegBankId := RequestMatrixRegBankId TableItem.MatrixRegAddr := RequestMatrixRegAddr + TableItem.MatrixRegisTail := RequestBeatIsTail + TableItem.BeatIndex := Request_M_Iter_Time SoureceIdSearchTable(sourceId.bits) := TableItem.asUInt - Request_M_Iter_Time := Request_M_Iter_Time + 1.U - when(Request_M_Iter_Time === (Matrix_MN - 1).U || (CurrentLoaded_BlockTensor_M_Iter + Request_M_Iter_Time) === MatrixRegTensor_M - 1.U){ - Request_M_Iter_Time := 0.U - CurrentLoaded_BlockTensor_K_Iter := CurrentLoaded_BlockTensor_K_Iter + (outsideDataWidthByte.U / ReduceWidthByte.U) - when(CurrentLoaded_BlockTensor_K_Iter + (outsideDataWidthByte.U / ReduceWidthByte.U) === MatrixRegTensor_K){ - CurrentLoaded_BlockTensor_K_Iter := 0.U - CurrentLoaded_BlockTensor_M_Iter := CurrentLoaded_BlockTensor_M_Iter + Matrix_MN.U + if (YJPAMLDebugEnable) { + printf("[AML_RequestHandshake<%d>] M_Iter:%d, K_Iter:%d, ReqTime:%d, Addr:%x, BankId:%d, RegAddr:%d, SourceId:%d, Tail:%d\n", + io.DebugInfo.DebugTimeStampe, CurrentLoaded_BlockTensor_M_Iter, CurrentLoaded_BlockTensor_K_Iter, + Request_M_Iter_Time, Request.bits.RequestVirtualAddr, RequestMatrixRegBankId, RequestMatrixRegAddr, sourceId.bits, RequestBeatIsTail) + } + + when(Is_Transpose) { + printf("[AML_TRANS_REQ<%d>] totalReq:%d groupReq:%d groupResp:%d idle:%d largeM:%d kBeat:%d beat:%d groupSize:%d activeGroupSize:%d regBase:%d vaddr:%x source:%d tail:%d\n", + io.DebugInfo.DebugTimeStampe, TotalRequestSize, group_req_cnt, group_resp_cnt, + group_is_idle.asUInt, CurrentLoaded_BlockTensor_M_Iter, CurrentLoaded_BlockTensor_K_Iter, + Request_M_Iter_Time, current_group_size, active_group_size, RequestMatrixRegAddr, + Request.bits.RequestVirtualAddr, sourceId.bits, RequestBeatIsTail.asUInt) + + val small_m_reach_group_boundary = Request_M_Iter_Time === (ABMatrixRegEntryByteSize - 1).U + val small_m_reach_tensor_boundary = transpose_current_m === (MatrixRegTensor_M - 1.U) + val small_m_wrap = small_m_reach_group_boundary || small_m_reach_tensor_boundary + val k_wrap = CurrentLoaded_BlockTensor_K_Iter === (K_Beat_Count - 1.U) + + when(group_is_idle) { + group_size_reg := current_group_size + } + group_req_cnt := group_req_cnt + 1.U + + Request_M_Iter_Time := Request_M_Iter_Time + 1.U + when(small_m_wrap) { + Request_M_Iter_Time := 0.U + CurrentLoaded_BlockTensor_K_Iter := CurrentLoaded_BlockTensor_K_Iter + 1.U + when(k_wrap) { + CurrentLoaded_BlockTensor_K_Iter := 0.U + CurrentLoaded_BlockTensor_M_Iter := CurrentLoaded_BlockTensor_M_Iter + 1.U + } + } + }.otherwise { + Request_M_Iter_Time := Request_M_Iter_Time + 1.U + when(Request_M_Iter_Time === (Matrix_MN - 1).U || (CurrentLoaded_BlockTensor_M_Iter + Request_M_Iter_Time) === MatrixRegTensor_M - 1.U){ + Request_M_Iter_Time := 0.U + CurrentLoaded_BlockTensor_K_Iter := CurrentLoaded_BlockTensor_K_Iter + 1.U + when(CurrentLoaded_BlockTensor_K_Iter + 1.U === K_Beat_Count){ + CurrentLoaded_BlockTensor_K_Iter := 0.U + CurrentLoaded_BlockTensor_M_Iter := CurrentLoaded_BlockTensor_M_Iter + Matrix_MN.U + } } } when(TotalRequestSize =/= MaxRequestIter){ @@ -221,62 +594,191 @@ class AMemoryLoader(implicit p: Parameters) extends CuteModule{ } val current_fill_fifo_full = WireInit(false.B) - when(io.LocalMMUIO.Response.valid){ + when(io.LocalMMUIO.Response.valid && !Is_Transpose){ val respSourceId = io.LocalMMUIO.Response.bits.ReseponseSourceID val MatrixRegBankId = SoureceIdSearchTable(respSourceId).asTypeOf(new ASourceIdSearch).MatrixRegBankId current_fill_fifo_full := Bank_Fill_Search_FIFO_Full(MatrixRegBankId) } - io.LocalMMUIO.Response.ready := MReg_Fill_Table_Not_Full && !current_fill_fifo_full + io.LocalMMUIO.Response.ready := Mux(Is_Transpose, !transBusStall, MReg_Fill_Table_Not_Full && !current_fill_fifo_full) when(io.LocalMMUIO.Response.fire){ val sourceId = io.LocalMMUIO.Response.bits.ReseponseSourceID - val MatrixRegBankId = SoureceIdSearchTable(sourceId).asTypeOf(new ASourceIdSearch).MatrixRegBankId - val MatrixRegAddr = SoureceIdSearchTable(sourceId).asTypeOf(new ASourceIdSearch).MatrixRegAddr + val MatrixRegSearch = SoureceIdSearchTable(sourceId).asTypeOf(new ASourceIdSearch) + val MatrixRegBankId = MatrixRegSearch.MatrixRegBankId + val MatrixRegAddr = MatrixRegSearch.MatrixRegAddr val ResponseData = io.LocalMMUIO.Response.bits.ReseponseData val FIFOIndex = Bank_Fill_Search_FIFO_Head(MatrixRegBankId) - MReg_Fill_Table(MReg_Fill_Table_Insert_Index) := ResponseData - MReg_Fill_Table_MReg_Addr(MReg_Fill_Table_Insert_Index) := MatrixRegAddr - MReg_Fill_Table_Time(MReg_Fill_Table_Insert_Index) := MAX_Fill_Times.U - Bank_Fill_Search_FIFO(MatrixRegBankId)(FIFOIndex) := MReg_Fill_Table_Insert_Index - Bank_Fill_Search_FIFO_Head(MatrixRegBankId) := WrapInc(Bank_Fill_Search_FIFO_Head(MatrixRegBankId), AMemoryLoaderReadFromMemoryFIFODepth) + if (YJPAMLDebugEnable) { + printf("[AML_ResponseHandshake<%d>] Data:%x, BankId:%d, RegAddr:%d, SourceId:%d, FIFOIndex:%d, Tail:%d\n", io.DebugInfo.DebugTimeStampe, ResponseData, MatrixRegBankId, MatrixRegAddr, sourceId, FIFOIndex, MatrixRegSearch.MatrixRegisTail) + } + + when(Is_Transpose) { + val next_group_resp_cnt = group_resp_cnt + 1.U + val drain_trigger = next_group_resp_cnt === active_group_size + + printf("[AML_TRANS_RESP<%d>] source:%d data:%x tail:%d mask:%x base:%d beatIndex:%d respCnt:%d nextResp:%d activeGroupSize:%d drainTrig:%d writeBase:%d writeCnt:%d\n", + io.DebugInfo.DebugTimeStampe, sourceId, ResponseData, MatrixRegSearch.MatrixRegisTail.asUInt, + Mux(MatrixRegSearch.MatrixRegisTail, tailTaskMask, fullTaskMask), MatrixRegAddr, + MatrixRegSearch.BeatIndex, group_resp_cnt, next_group_resp_cnt, active_group_size, + drain_trigger.asUInt, transWriteBaseAddr, transWriteAddrCnt) + + transPipeInValid := true.B + transPipeInData := ResponseData + transPipeInMask := Mux(MatrixRegSearch.MatrixRegisTail, tailTaskMask, fullTaskMask) + transPipeRespBeatCnt := group_resp_cnt + transPipeEntryOffset := MatrixRegSearch.BeatIndex + transPipeDrainTrigger := drain_trigger + + when(group_resp_cnt === 0.U) { + transWriteBaseAddr := MatrixRegAddr + transWriteAddrCnt := 0.U + } + + when(next_group_resp_cnt === active_group_size) { + group_req_cnt := 0.U + group_resp_cnt := 0.U + group_size_reg := 0.U + }.otherwise { + group_resp_cnt := next_group_resp_cnt + } + }.otherwise { + MReg_Fill_Table(MReg_Fill_Table_Insert_Index) := ResponseData + MReg_Fill_Table_MReg_Addr(MReg_Fill_Table_Insert_Index) := MatrixRegAddr + MReg_Fill_Table_Time(MReg_Fill_Table_Insert_Index) := MAX_Fill_Times.U + MReg_Fill_Table_IsTail(MReg_Fill_Table_Insert_Index) := MatrixRegSearch.MatrixRegisTail + Bank_Fill_Search_FIFO(MatrixRegBankId)(FIFOIndex) := MReg_Fill_Table_Insert_Index + Bank_Fill_Search_FIFO_Head(MatrixRegBankId) := WrapInc(Bank_Fill_Search_FIFO_Head(MatrixRegBankId), AMemoryLoaderReadFromMemoryFIFODepth) + } if (YJPAMLDebugEnable){ printf("[AML<%d>]Response, Data:%x, BankId:%d, RegAddr:%d\n", io.DebugInfo.DebugTimeStampe, ResponseData, MatrixRegBankId, MatrixRegAddr) } } - val Current_Fill_MReg_Time = WireInit(VecInit(Seq.fill(ABMatrixRegNBanks)(0.U(1.W)))) - for (i <- 0 until ABMatrixRegNBanks){ - when(Bank_Fill_Search_FIFO_Empty(i) === false.B){ - val CurrentFIFOIndex = Bank_Fill_Search_FIFO(i)(Bank_Fill_Search_FIFO_Tail(i)) - Current_Fill_MReg_Time(i) := 1.U - val FIFOData = WireInit((VecInit(Seq.fill(MAX_Fill_Times)(0.U((8*ABMatrixRegEntryByteSize).W))))) - FIFOData := MReg_Fill_Table(CurrentFIFOIndex).asTypeOf(FIFOData) - io.ToMatrixRegIO.BankAddr(i).bits := MReg_Fill_Table_MReg_Addr(CurrentFIFOIndex) + (MAX_Fill_Times.U - MReg_Fill_Table_Time(CurrentFIFOIndex)) - io.ToMatrixRegIO.BankAddr(i).valid := true.B - io.ToMatrixRegIO.Data(i).bits := FIFOData(MAX_Fill_Times.U - MReg_Fill_Table_Time(CurrentFIFOIndex)) - io.ToMatrixRegIO.Data(i).valid := true.B - MReg_Fill_Table_Time(CurrentFIFOIndex) := MReg_Fill_Table_Time(CurrentFIFOIndex) - 1.U - when(MReg_Fill_Table_Time(CurrentFIFOIndex) === 1.U){ - Bank_Fill_Search_FIFO_Tail(i) := WrapInc(Bank_Fill_Search_FIFO_Tail(i), AMemoryLoaderReadFromMemoryFIFODepth) + when(Is_Transpose) { + for (i <- 0 until ABMatrixRegNBanks) { + val routerValid = transRouters(i).io.valid + io.ToMatrixRegIO.BankAddr(i).bits := transWriteAddr + io.ToMatrixRegIO.BankAddr(i).valid := routerValid + io.ToMatrixRegIO.Data(i).bits := transRouters(i).io.final_data + io.ToMatrixRegIO.Data(i).valid := routerValid + io.ToMatrixRegIO.ByteMask(i).bits := transRouters(i).io.final_mask + io.ToMatrixRegIO.ByteMask(i).valid := routerValid + if (YJPAMLDebugEnable) { + when(routerValid) { + printf("[AML_TransposeWrite<%d>] bank:%d, Addr:%d, Data:%x, Mask:%x\n", + io.DebugInfo.DebugTimeStampe, i.U, io.ToMatrixRegIO.BankAddr(i).bits, + transRouters(i).io.final_data, transRouters(i).io.final_mask) + } + } + } + when(transRouterWriteValid) { + printf("[AML_TRANS_WRITE<%d>] validVec:%b base:%d cnt:%d addr:%d bank0Data:%x bank0Mask:%x totalLoad:%d pipelineEmpty:%d\n", + io.DebugInfo.DebugTimeStampe, transRouterValidVec, transWriteBaseAddr, transWriteAddrCnt, + transWriteAddr, transRouters(0).io.final_data, transRouters(0).io.final_mask, + TotalLoadSize, transPipelineEmpty.asUInt) + transWriteAddrCnt := Mux( + transWriteAddrCnt === (Trans_Load_Size - 1).U, + 0.U, + transWriteAddrCnt + 1.U + ) + } + // 转置路径以真实写侧 valid 计数;任务结束只看请求、响应和流水线是否全部静默。 + val Current_Load_Fill_Size = transRouterWriteValid.asUInt + val nextTotalLoadSize = TotalLoadSize + Current_Load_Fill_Size + val transposeDone = TotalRequestSize === MaxRequestIter && + group_req_cnt === 0.U && group_resp_cnt === 0.U && + transPipelineEmpty + TotalLoadSize := nextTotalLoadSize + if(YJPAMLDebugEnable){ + when(Current_Load_Fill_Size =/= 0.U) { + printf("[AML_TransposeLoad<%d>]TotalLoadSize:%d, FillSize:%d\n", io.DebugInfo.DebugTimeStampe, TotalLoadSize, Current_Load_Fill_Size) } - if (YJPAMLDebugEnable){ - printf("[AML<%d>]Fill bank:%d, RegAddr:%x, Time:%d\n", io.DebugInfo.DebugTimeStampe, i.U, MReg_Fill_Table_MReg_Addr(CurrentFIFOIndex), MReg_Fill_Table_Time(CurrentFIFOIndex)) + } + when(transposeDone){ + memoryload_state := s_load_quiesce + transposeEndDrainCnt := (transposeEndDrainCycles - 1).U + if (YJPAMLDebugEnable) printf("[AML<%d>]TransposeFullLoadEnd\n", io.DebugInfo.DebugTimeStampe) + } + }.otherwise { + val Current_Fill_MReg_Time = WireInit(VecInit(Seq.fill(ABMatrixRegNBanks)(0.U(1.W)))) + for (i <- 0 until ABMatrixRegNBanks){ + when(Bank_Fill_Search_FIFO_Empty(i) === false.B){ + val CurrentFIFOIndex = Bank_Fill_Search_FIFO(i)(Bank_Fill_Search_FIFO_Tail(i)) + val fillSlot = MAX_Fill_Times.U - MReg_Fill_Table_Time(CurrentFIFOIndex) + val fillSlotOH = UIntToOH(fillSlot, MAX_Fill_Times) + val fillLowHalf = fillSlot(0) === 0.U + val currentIsTail = MReg_Fill_Table_IsTail(CurrentFIFOIndex) + Current_Fill_MReg_Time(i) := 1.U + val FIFOData = WireInit((VecInit(Seq.fill(MAX_Fill_Times)(0.U((8*ABMatrixRegEntryByteSize).W))))) + FIFOData := MReg_Fill_Table(CurrentFIFOIndex).asTypeOf(FIFOData) + io.ToMatrixRegIO.BankAddr(i).bits := MReg_Fill_Table_MReg_Addr(CurrentFIFOIndex) + fillSlot + io.ToMatrixRegIO.BankAddr(i).valid := true.B + io.ToMatrixRegIO.Data(i).bits := Mux(fillLowHalf, FIFOData(0), FIFOData(1)) + io.ToMatrixRegIO.Data(i).valid := true.B + io.ToMatrixRegIO.ByteMask(i).bits := Mux(currentIsTail && fillSlotOH(1), tailTaskMask(63, 32), Mux(currentIsTail && fillSlotOH(0), tailTaskMask(31, 0), Fill(ABMatrixRegEntryByteSize, true.B))) + io.ToMatrixRegIO.ByteMask(i).valid := true.B + if (YJPAMLDebugEnable) { + printf("[AML_MRegWriteHandshake<%d>] bank:%d, RegAddr:%x, WriteAddr:%x, Data:%x, ByteMask:%x, Time:%d\n", io.DebugInfo.DebugTimeStampe, i.U, MReg_Fill_Table_MReg_Addr(CurrentFIFOIndex), io.ToMatrixRegIO.BankAddr(i).bits, io.ToMatrixRegIO.Data(i).bits, io.ToMatrixRegIO.ByteMask(i).bits, MReg_Fill_Table_Time(CurrentFIFOIndex)) + } + MReg_Fill_Table_Time(CurrentFIFOIndex) := MReg_Fill_Table_Time(CurrentFIFOIndex) - 1.U + when(MReg_Fill_Table_Time(CurrentFIFOIndex) === 1.U){ + Bank_Fill_Search_FIFO_Tail(i) := WrapInc(Bank_Fill_Search_FIFO_Tail(i), AMemoryLoaderReadFromMemoryFIFODepth) + } + if (YJPAMLDebugEnable){ + printf("[AML<%d>]Fill bank:%d, RegAddr:%x, Time:%d\n", io.DebugInfo.DebugTimeStampe, i.U, MReg_Fill_Table_MReg_Addr(CurrentFIFOIndex), MReg_Fill_Table_Time(CurrentFIFOIndex)) + } } } - } - val Current_Load_Fill_Size = PopCount(Current_Fill_MReg_Time.asUInt) - TotalLoadSize := TotalLoadSize + Current_Load_Fill_Size - if(YJPAMLDebugEnable){ - when(Current_Load_Fill_Size =/= 0.U) { - printf("[AML<%d>]TotalLoadSize:%d, FillSize:%d, Max:%d\n", io.DebugInfo.DebugTimeStampe, TotalLoadSize, Current_Load_Fill_Size, MaxRequestIter * MAX_Fill_Times.U) + val Current_Load_Fill_Size = PopCount(Current_Fill_MReg_Time.asUInt) + TotalLoadSize := TotalLoadSize + Current_Load_Fill_Size + if(YJPAMLDebugEnable){ + when(Current_Load_Fill_Size =/= 0.U) { + printf("[AML<%d>]TotalLoadSize:%d, FillSize:%d, Max:%d\n", io.DebugInfo.DebugTimeStampe, TotalLoadSize, Current_Load_Fill_Size, MaxRequestIter * MAX_Fill_Times.U) + } + } + when(TotalLoadSize === (MaxRequestIter * MAX_Fill_Times.U)){ + memoryload_state := s_load_end + if (YJPAMLDebugEnable) printf("[AML<%d>]FullLoadEnd\n", io.DebugInfo.DebugTimeStampe) } } - when(TotalLoadSize === (MaxRequestIter * MAX_Fill_Times.U)){ - memoryload_state := s_load_end - if (YJPAMLDebugEnable) printf("[AML<%d>]FullLoadEnd\n", io.DebugInfo.DebugTimeStampe) + } + } + is(s_load_quiesce) { + io.ToMatrixRegIO.active := true.B + for (i <- 0 until ABMatrixRegNBanks) { + val routerValid = transRouters(i).io.valid + io.ToMatrixRegIO.BankAddr(i).bits := transWriteAddr + io.ToMatrixRegIO.BankAddr(i).valid := routerValid + io.ToMatrixRegIO.Data(i).bits := transRouters(i).io.final_data + io.ToMatrixRegIO.Data(i).valid := routerValid + io.ToMatrixRegIO.ByteMask(i).bits := transRouters(i).io.final_mask + io.ToMatrixRegIO.ByteMask(i).valid := routerValid + if (YJPAMLDebugEnable) { + when(routerValid) { + printf("[AML_TransposeQuiesceWrite<%d>] bank:%d, Addr:%d, Data:%x, Mask:%x\n", + io.DebugInfo.DebugTimeStampe, i.U, io.ToMatrixRegIO.BankAddr(i).bits, + transRouters(i).io.final_data, transRouters(i).io.final_mask) + } } } + when(transRouterWriteValid) { + printf("[AML_TRANS_QUIESCE_WRITE<%d>] validVec:%b base:%d cnt:%d addr:%d bank0Data:%x bank0Mask:%x drainCnt:%d pipelineEmpty:%d\n", + io.DebugInfo.DebugTimeStampe, transRouterValidVec, transWriteBaseAddr, transWriteAddrCnt, + transWriteAddr, transRouters(0).io.final_data, transRouters(0).io.final_mask, + transposeEndDrainCnt, transPipelineEmpty.asUInt) + transWriteAddrCnt := Mux( + transWriteAddrCnt === (Trans_Load_Size - 1).U, + 0.U, + transWriteAddrCnt + 1.U + ) + } + when(transposeEndDrainCnt === 0.U) { + memoryload_state := s_load_end + if (YJPAMLDebugEnable) printf("[AML<%d>]TransposeQuiesceEnd\n", io.DebugInfo.DebugTimeStampe) + }.otherwise { + transposeEndDrainCnt := transposeEndDrainCnt - 1.U + } } is(s_load_end) { ConfigInfo.MicroTaskEndValid := true.B diff --git a/src/main/scala/BMemoryLoader.scala b/src/main/scala/BMemoryLoader.scala index b68e11c..3281622 100644 --- a/src/main/scala/BMemoryLoader.scala +++ b/src/main/scala/BMemoryLoader.scala @@ -21,6 +21,8 @@ import org.chipsalliance.cde.config._ class BSourceIdSearch(implicit p: Parameters) extends CuteBundle{ val MatrixRegBankId = UInt(log2Ceil(ABMatrixRegNBanks).W) val MatrixRegAddr = UInt(log2Ceil(ABMatrixRegBankNEntries).W) + val MatrixRegisTail = Bool() + val BeatIndex = UInt(log2Ceil(ABMatrixRegEntryByteSize).W) } //对于卷积,数据摆放是[khkwoc][ic],对于矩阵乘,数据摆放是[N][K] @@ -41,8 +43,11 @@ class BMemoryLoader(implicit p: Parameters) extends CuteModule{ io.ToMatrixRegIO.BankAddr.map(_.bits := DontCare) io.ToMatrixRegIO.Data.map(_.valid := false.B) io.ToMatrixRegIO.Data.map(_.bits := DontCare) + io.ToMatrixRegIO.ByteMask.map(_.valid := false.B) + io.ToMatrixRegIO.ByteMask.map(_.bits := Fill(ABMatrixRegEntryByteSize, true.B)) io.LocalMMUIO.Request.valid := false.B io.LocalMMUIO.Request.bits := DontCare // It will be set if Request is valid + io.LocalMMUIO.Request.bits.RequestMask := Fill(MMUMaskWidth, 1.U(1.W)) io.LocalMMUIO.Response.ready := false.B io.ConfigInfo.MicroTaskEndValid := false.B io.ConfigInfo.MicroTaskReady := false.B @@ -68,6 +73,7 @@ class BMemoryLoader(implicit p: Parameters) extends CuteModule{ for (i <- 0 until ABMatrixRegNBanks) { difftestAmuFinish.bankValid(i) := io.ToMatrixRegIO.BankAddr(i).valid difftestAmuFinish.bankAddr(i) := io.ToMatrixRegIO.BankAddr(i).bits + difftestAmuFinish.bankMask(i) := io.ToMatrixRegIO.ByteMask(i).bits for (w <- 0 until eventWordsPerBank) { if (w < abMRegWordsPerBank) { val lo = w * 64 @@ -91,9 +97,11 @@ class BMemoryLoader(implicit p: Parameters) extends CuteModule{ val MatrixRegTensor_N = RegInit(0.U(MatrixRegMaxTensorDimBitSize.W)) val MatrixRegTensor_K = RegInit(0.U(MatrixRegMaxTensorDimBitSize.W)) - + val HasTail = RegInit(false.B) + val TailByteMask = RegInit(0.U(log2Ceil(outsideDataWidthByte + 1).W)) + val K_Beat_Count = RegInit(0.U(MatrixRegMaxTensorDimBitSize.W)) + val dataType = RegInit(0.U(ElementDataType.DataTypeBitWidth.W)) val Tensor_B_BaseVaddr = RegInit(0.U(MMUAddrWidth.W)) - val ApplicationTensor_B_Stride_N = RegInit(0.U(MMUAddrWidth.W)) @@ -103,11 +111,14 @@ class BMemoryLoader(implicit p: Parameters) extends CuteModule{ //访存状态机,用来配合流水线刷新 - val s_load_idle :: s_load_init :: s_load_working :: s_load_end :: Nil = Enum(4) + val s_load_idle :: s_load_init :: s_load_working :: s_load_quiesce :: s_load_end :: Nil = Enum(5) val memoryload_state = RegInit(s_load_idle) + private val transposeEndDrainCycles = Trans_Load_Size + 2 + 3 + val transposeEndDrainCnt = RegInit(0.U(log2Ceil(transposeEndDrainCycles + 1).W)) val Tensor_Block_BaseAddr = Reg(UInt(MMUAddrWidth.W)) //分块矩阵的基地址 val Conherent = RegInit(true.B) //是否一致性访存的标志位,由TaskController提供 + val Is_Transpose = RegInit(false.B) //是否转置load,由TaskController提供 //如果configinfo有效 @@ -126,11 +137,16 @@ class BMemoryLoader(implicit p: Parameters) extends CuteModule{ Tensor_B_BaseVaddr := io.ConfigInfo.ApplicationTensor_B.ApplicationTensor_B_BaseVaddr //这个不重要 Tensor_Block_BaseAddr := io.ConfigInfo.ApplicationTensor_B.BlockTensor_B_BaseVaddr //这个是关键 Conherent := io.ConfigInfo.Conherent + Is_Transpose := io.ConfigInfo.Is_Transpose + HasTail := io.ConfigInfo.ApplicationTensor_B.HasTail + TailByteMask := io.ConfigInfo.ApplicationTensor_B.TailByteMask + K_Beat_Count := io.ConfigInfo.ApplicationTensor_B.K_Beat_Count + dataType := io.ConfigInfo.ApplicationTensor_B.dataType ApplicationTensor_B_Stride_N := io.ConfigInfo.ApplicationTensor_B.ApplicationTensor_B_Stride_N //下一个N,需要增加多少地址偏移量 if(YJPBMLDebugEnable) { printf("[BML<%d>]BMemoryLoader Task Start\n",io.DebugInfo.DebugTimeStampe) - printf("[BML<%d>]MatrixRegTensor_N:%d,MatrixRegTensor_K:%d\n",io.DebugInfo.DebugTimeStampe,io.ConfigInfo.MatrixRegTensor_N,io.ConfigInfo.MatrixRegTensor_K) + printf("[BML<%d>]MatrixRegTensor_N:%d,MatrixRegTensor_K:%d,Is_Transpose:%d\n",io.DebugInfo.DebugTimeStampe,io.ConfigInfo.MatrixRegTensor_N,io.ConfigInfo.MatrixRegTensor_K,io.ConfigInfo.Is_Transpose) printf("[BML<%d>]Tensor_B_BaseVaddr:%x,Tensor_Block_BaseAddr:%x\n",io.DebugInfo.DebugTimeStampe,io.ConfigInfo.ApplicationTensor_B.ApplicationTensor_B_BaseVaddr,io.ConfigInfo.ApplicationTensor_B.BlockTensor_B_BaseVaddr) printf("[BML<%d>]ApplicationTensor_B_Stride_N:%x\n",io.DebugInfo.DebugTimeStampe,io.ConfigInfo.ApplicationTensor_B.ApplicationTensor_B_Stride_N) } @@ -160,11 +176,17 @@ class BMemoryLoader(implicit p: Parameters) extends CuteModule{ //如果是memoryload_state === s_load_working,那么我们就要开始取数 //如果是memoryload_state === s_load_end,那么我们就要结束取数 val TotalLoadSize = RegInit(0.U((log2Ceil(Tensor_MN*ReduceGroupSize*ReduceWidthByte)+1).W)) //总共要加载的数据量 - val CurrentLoaded_BlockTensor_N = RegInit(0.U(MatrixRegMaxTensorDimBitSize.W)) - val CurrentLoaded_BlockTensor_K = RegInit(0.U(MatrixRegMaxTensorDimBitSize.W)) + val TotalRequestSize = RegInit(0.U((log2Ceil(Tensor_MN*ReduceGroupSize*ReduceWidthByte)).W)) + val CurrentLoaded_BlockTensor_N_Iter = RegInit(0.U(MatrixRegMaxTensorDimBitSize.W)) + val CurrentLoaded_BlockTensor_K_Iter = RegInit(0.U(MatrixRegMaxTensorDimBitSize.W)) + val Request_N_Iter_Time = RegInit(0.U(log2Ceil(math.max(Matrix_MN, ABMatrixRegEntryByteSize)).W)) + + val group_req_cnt = RegInit(0.U(log2Ceil(ABMatrixRegEntryByteSize + 1).W)) + val group_resp_cnt = RegInit(0.U(log2Ceil(ABMatrixRegEntryByteSize + 1).W)) + val group_size_reg = RegInit(0.U(log2Ceil(ABMatrixRegEntryByteSize + 1).W)) + + private val transBaseAddrBits = log2Ceil(ABMatrixRegBankNEntries) - val MaxBlockTensor_N_Index = MatrixRegTensor_N - val MaxBlockTensor_K_Index = MatrixRegTensor_K //一个cam来存储访存请求的source_id对应的MatrixReg的地址和bank号 //用sourceid做索引,存储MatrixReg的地址和bank号,是一组寄存器 @@ -176,6 +198,7 @@ class BMemoryLoader(implicit p: Parameters) extends CuteModule{ val MReg_Fill_Table = RegInit((VecInit(Seq.fill(BMemoryLoaderReadFromMemoryFIFODepth)(0.U(outsideDataWidth.W))))) val MReg_Fill_Table_MReg_Addr = RegInit((VecInit(Seq.fill(BMemoryLoaderReadFromMemoryFIFODepth)(0.U(log2Ceil(ABMatrixRegBankNEntries).W)))))//记录这个LLC回的数是在scp的哪个地址 val MReg_Fill_Table_Time = RegInit((VecInit(Seq.fill(BMemoryLoaderReadFromMemoryFIFODepth)(0.U((log2Ceil(outsideDataWidthByte/ABMatrixRegEntryByteSize)+1).W)))))//记录这个LLC回的数需要回填的次数,完成就可以将数据释放了 + val MReg_Fill_Table_IsTail = RegInit(VecInit(Seq.fill(BMemoryLoaderReadFromMemoryFIFODepth)(false.B))) val MReg_Fill_Table_Free = MReg_Fill_Table_Time.map(_ === 0.U)//记录这个FIFO能否能填数据 val MReg_Fill_Table_Valid = MReg_Fill_Table_Time.map(_ =/= 0.U)//记录这个FIFO里的数据是否有效 val MReg_Fill_Table_Insert_Index = PriorityEncoder(MReg_Fill_Table_Free)//返回第一个空位的index @@ -195,78 +218,167 @@ class BMemoryLoader(implicit p: Parameters) extends CuteModule{ Bank_Fill_Search_FIFO_Empty(i) := Bank_Fill_Search_FIFO_Head(i) === Bank_Fill_Search_FIFO_Tail(i)//这个bank不需要写scp } + val transAlignPipes = Seq.tabulate(ABMatrixRegNBanks)(i => Module(new TransAlignPipe(i))) + val transRouters = Seq.tabulate(ABMatrixRegNBanks)(_ => Module(new OOORouter)) + val transPipeInValid = WireInit(false.B) + val transPipeInData = WireInit(0.U(outsideDataWidth.W)) + val transPipeInMask = WireInit(0.U(outsideDataWidthByte.W)) + val transPipeRespBeatCnt = WireInit(0.U(group_resp_cnt.getWidth.W)) + val transPipeEntryOffset = WireInit(0.U(log2Ceil(ABMatrixRegEntryByteSize).W)) + val transPipeDrainTrigger = WireInit(false.B) + val transBusStall = transAlignPipes.map(_.io.bus_stall).reduce(_ || _) + val transWriteBaseAddr = RegInit(0.U(transBaseAddrBits.W)) + val transWriteAddrCnt = RegInit(0.U(log2Ceil(Trans_Load_Size).W)) + + for (i <- 0 until ABMatrixRegNBanks) { + transAlignPipes(i).io.in_data := transPipeInData + transAlignPipes(i).io.in_mask := transPipeInMask + transAlignPipes(i).io.resp_beat_cnt := transPipeRespBeatCnt + transAlignPipes(i).io.entry_offset := transPipeEntryOffset + transAlignPipes(i).io.debug_time := io.DebugInfo.DebugTimeStampe + transAlignPipes(i).io.is_drain_trigger := transPipeDrainTrigger + transAlignPipes(i).io.in_valid := transPipeInValid + transRouters(i).io.in <> transAlignPipes(i).io.out + } + val transAlignEmpty = transAlignPipes.map(_.io.empty).reduce(_ && _) + val transRouterEmpty = transRouters.map(_.io.empty).reduce(_ && _) + val transPipelineEmpty = transAlignEmpty && transRouterEmpty + val transRouterValidVec = VecInit(transRouters.map(_.io.valid)).asUInt + val transRouterWriteValid = transRouters.map(_.io.valid).reduce(_ || _) + val transWriteAddrOffset = transWriteAddrCnt * ReduceGroupSize.U + val transWriteAddrWide = transWriteBaseAddr + transWriteAddrOffset + val transWriteAddr = transWriteAddrWide(transBaseAddrBits - 1, 0) + val Request = io.LocalMMUIO.Request switch(memoryload_state) { is(s_load_init) { memoryload_state := s_load_working TotalLoadSize := 0.U - CurrentLoaded_BlockTensor_N := 0.U - CurrentLoaded_BlockTensor_K := 0.U - MaxRequestIter := MatrixRegTensor_K * MatrixRegTensor_N * ReduceWidthByte.U / (outsideDataWidthByte.U) //总共要发出的访存请求的次数 + TotalRequestSize := 0.U + CurrentLoaded_BlockTensor_N_Iter := 0.U + CurrentLoaded_BlockTensor_K_Iter := 0.U + Request_N_Iter_Time := 0.U + group_req_cnt := 0.U + group_resp_cnt := 0.U + group_size_reg := 0.U + transposeEndDrainCnt := 0.U + transWriteBaseAddr := 0.U + transWriteAddrCnt := 0.U + MaxRequestIter := MatrixRegTensor_N * K_Beat_Count //总共要发出的访存请求的次数 + Bank_Fill_Search_FIFO := 0.U.asTypeOf(Bank_Fill_Search_FIFO) + Bank_Fill_Search_FIFO_Head := 0.U.asTypeOf(Bank_Fill_Search_FIFO_Head) + Bank_Fill_Search_FIFO_Tail := 0.U.asTypeOf(Bank_Fill_Search_FIFO_Tail) + MReg_Fill_Table := 0.U.asTypeOf(MReg_Fill_Table) + MReg_Fill_Table_MReg_Addr := 0.U.asTypeOf(MReg_Fill_Table_MReg_Addr) + MReg_Fill_Table_Time := 0.U.asTypeOf(MReg_Fill_Table_Time) + MReg_Fill_Table_IsTail := VecInit(Seq.fill(BMemoryLoaderReadFromMemoryFIFODepth)(false.B)) } is(s_load_working) { io.ToMatrixRegIO.active := true.B //根据不同的MemoryOrder,执行不同的访存模式 - //只要Request是ready,我们发出的访存请求就会被MMU送往总线,我们可以发出下一个访存请求 - //不用担心乘法电路延迟,再不济,可以提前几个周期将乘法结果算好,做成fifo送进来 - Request.bits.RequestVirtualAddr := Tensor_Block_BaseAddr + (CurrentLoaded_BlockTensor_N * ApplicationTensor_B_Stride_N) + (CurrentLoaded_BlockTensor_K * ReduceWidthByte.U) + //先转换成独热码然后进行减一即可计算出掩码 + val tailTaskMask = UIntToOH(TailByteMask, outsideDataWidthByte + 1).asUInt - 1.U(outsideDataWidthByte.W) + val fullTaskMask = Fill(outsideDataWidthByte, true.B) + val RequestBeatIsTail = HasTail && (CurrentLoaded_BlockTensor_K_Iter === (K_Beat_Count - 1.U)) + + // Transpose request path reuses the normal iter registers: + // CurrentLoaded_BlockTensor_N_Iter -> large_N group index + // CurrentLoaded_BlockTensor_K_Iter -> K/N-beat index + // Request_N_Iter_Time -> small_N inside current group + val transpose_large_n_base = CurrentLoaded_BlockTensor_N_Iter * ABMatrixRegEntryByteSize.U + val transpose_current_n = transpose_large_n_base + Request_N_Iter_Time + val transpose_group_in_range = transpose_large_n_base < MatrixRegTensor_N + val transpose_group_remain = Mux(transpose_group_in_range, MatrixRegTensor_N - transpose_large_n_base, 0.U) + val current_group_size = Wire(UInt(log2Ceil(ABMatrixRegEntryByteSize + 1).W)) + current_group_size := Mux( + transpose_group_remain < ABMatrixRegEntryByteSize.U, + transpose_group_remain(current_group_size.getWidth - 1, 0), + ABMatrixRegEntryByteSize.U(current_group_size.getWidth.W) + ) + val group_has_no_requests = group_req_cnt === 0.U && group_resp_cnt === 0.U + val group_is_idle = group_has_no_requests && transPipelineEmpty + val active_group_size = Mux(group_has_no_requests, current_group_size, group_size_reg) + val transpose_group_can_issue = Mux( + group_has_no_requests, + group_is_idle && (current_group_size =/= 0.U), + group_req_cnt < active_group_size + ) + val transpose_req_enable = (TotalRequestSize < MaxRequestIter) && transpose_group_can_issue + + // 访存顺序与AML保持一致:先沿N维分4个bank发射,再推进K维,最后推进下一组N block + val RequestMatrixRegNIndex = Mux(Is_Transpose, transpose_current_n, CurrentLoaded_BlockTensor_N_Iter + Request_N_Iter_Time) + val NormalRequestMatrixRegBankId = RequestMatrixRegNIndex % ABMatrixRegNBanks.U + val NormalRequestMatrixRegBaseAddr = ((RequestMatrixRegNIndex / ABMatrixRegNBanks.U) * ReduceGroupSize.U) + val NormalRequestMatrixRegAddr = NormalRequestMatrixRegBaseAddr + (CurrentLoaded_BlockTensor_K_Iter << log2Ceil(MAX_Fill_Times)) + val TransposeRequestMatrixRegAddr = CurrentLoaded_BlockTensor_K_Iter * (Trans_Load_Size * ReduceGroupSize).U + CurrentLoaded_BlockTensor_N_Iter + val RequestMatrixRegBankId = Mux(Is_Transpose, 0.U, NormalRequestMatrixRegBankId) + val RequestMatrixRegAddr = Mux(Is_Transpose, TransposeRequestMatrixRegAddr, NormalRequestMatrixRegAddr) + + Request.bits.RequestVirtualAddr := Tensor_Block_BaseAddr + RequestMatrixRegNIndex * ApplicationTensor_B_Stride_N + (CurrentLoaded_BlockTensor_K_Iter << log2Ceil(outsideDataWidthByte)) val sourceId = Mux(Conherent,io.LocalMMUIO.ConherentRequsetSourceID,io.LocalMMUIO.nonConherentRequsetSourceID) Request.bits.RequestConherent := Conherent Request.bits.RequestSourceID := sourceId.bits Request.bits.RequestType_isWrite := false.B - Request.valid := true.B - when(CurrentLoaded_BlockTensor_N === MaxBlockTensor_N_Index || CurrentLoaded_BlockTensor_K === MaxBlockTensor_K_Index)//Is_invalid_IH_IW时,不发出访存请求,尝试直接0填充 - { - Request.valid := false.B - } + Request.bits.RequestMask := Fill(MMUMaskWidth, 1.U(1.W)) + Request.valid := Mux(Is_Transpose, transpose_req_enable, TotalRequestSize < MaxRequestIter) - //数据在MatrixReg中的编排 - //数据会先排K,再排M - //AVector一定是不同M的数据,K不断送入,直到K迭代完成,再换新的M, - // K 0 1 2 3 4 5 6 7 time AVector MatrixRegData也这么排布 - // M 0 0 8 g o {bank[0] [1] [2] [3]} - // 0 0 1 2 3 4 5 6 7 1 1 9 h p |addr 0 | 0 8 g o - // 1 8 9 a b c d e f 2 2 a i q | 1 | 1 9 h p - // 2 g h i j k l m n 3 3 b j r | 2 | 2 a i q - // 3 o p g r s t u v 4 4 c k s | 3 | 3 b j r - // 4 w x y z ....... 5 5 d l t | 4 | 4 c k s - // 5 !.............. 6 6 e m u | 5 | 5 d l t - // 6 @.............. 7 7 f n v | 6 | 6 e m u - // 7 #.............. 8 w ! @ # | 7 | 7 f n v - // 8 $.............. 9 ....... | ........................... - // - // - // 在内存中的排布则是 0 1 2 3 4 5 6 7 8 9 a b c d e f g h i j k l m n o p q r s t u v w x y z ....... - - when(Request.fire && sourceId.valid){//符合条件的话,这条访存请求一定会被发出 //Request.ready表明了LocalMMU会处理这条访存请求,sourceID valid,表明这条访存请求的sourceID是被LocalMMU认可有效才发送到这个模块的 val TableItem = Wire(new BSourceIdSearch) - TableItem.MatrixRegBankId := CurrentLoaded_BlockTensor_N % ABMatrixRegNBanks.U - TableItem.MatrixRegAddr := ((CurrentLoaded_BlockTensor_N / ABMatrixRegNBanks.U) * ReduceGroupSize.U) + CurrentLoaded_BlockTensor_K + TableItem.MatrixRegBankId := RequestMatrixRegBankId + TableItem.MatrixRegAddr := RequestMatrixRegAddr + TableItem.MatrixRegisTail := RequestBeatIsTail + TableItem.BeatIndex := Request_N_Iter_Time SoureceIdSearchTable(sourceId.bits) := TableItem.asUInt - if (YJPBMLDebugEnable) - { - //输出id和request的信息 - printf("[BML<%d>]sourceId:%d,MatrixRegBankId:%d,MatrixRegAddr:%d\n",io.DebugInfo.DebugTimeStampe,sourceId.bits,TableItem.MatrixRegBankId,TableItem.MatrixRegAddr) - //输出这次request的信息 - printf("[BML<%d>]RequestVirtualAddr:%x,RequestConherent:%d,RequestSourceID:%d,RequestType_isWrite:%d\n",io.DebugInfo.DebugTimeStampe,Request.bits.RequestVirtualAddr,Request.bits.RequestConherent,Request.bits.RequestSourceID,Request.bits.RequestType_isWrite) + if (YJPBMLDebugEnable) { + printf("[BML_RequestHandshake<%d>] sourceId:%d, MatrixRegBankId:%d, MatrixRegAddr:%d, RequestVirtualAddr:%x, RequestConherent:%d, RequestType_isWrite:%d, Tail:%d\n",io.DebugInfo.DebugTimeStampe,sourceId.bits,TableItem.MatrixRegBankId,TableItem.MatrixRegAddr,Request.bits.RequestVirtualAddr,Request.bits.RequestConherent,Request.bits.RequestType_isWrite,RequestBeatIsTail) } - when(CurrentLoaded_BlockTensor_N < MaxBlockTensor_N_Index){ - when(CurrentLoaded_BlockTensor_K + MAX_Fill_Times.U < MaxBlockTensor_K_Index){ - //根据不同的内存Order,计算出访存请求的地址 - CurrentLoaded_BlockTensor_K := CurrentLoaded_BlockTensor_K + MAX_Fill_Times.U - }.otherwise{ - CurrentLoaded_BlockTensor_K := 0.U - CurrentLoaded_BlockTensor_N := CurrentLoaded_BlockTensor_N + 1.U + when(Is_Transpose) { + printf("[BML_TRANS_REQ<%d>] totalReq:%d groupReq:%d groupResp:%d idle:%d largeN:%d kBeat:%d beat:%d groupSize:%d activeGroupSize:%d regBase:%d vaddr:%x source:%d tail:%d\n", + io.DebugInfo.DebugTimeStampe, TotalRequestSize, group_req_cnt, group_resp_cnt, + group_is_idle.asUInt, CurrentLoaded_BlockTensor_N_Iter, CurrentLoaded_BlockTensor_K_Iter, + Request_N_Iter_Time, current_group_size, active_group_size, RequestMatrixRegAddr, + Request.bits.RequestVirtualAddr, sourceId.bits, RequestBeatIsTail.asUInt) + + val small_n_reach_group_boundary = Request_N_Iter_Time === (ABMatrixRegEntryByteSize - 1).U + val small_n_reach_tensor_boundary = transpose_current_n === (MatrixRegTensor_N - 1.U) + val small_n_wrap = small_n_reach_group_boundary || small_n_reach_tensor_boundary + val k_wrap = CurrentLoaded_BlockTensor_K_Iter === (K_Beat_Count - 1.U) + + when(group_is_idle) { + group_size_reg := current_group_size + } + group_req_cnt := group_req_cnt + 1.U + + Request_N_Iter_Time := Request_N_Iter_Time + 1.U + when(small_n_wrap) { + Request_N_Iter_Time := 0.U + CurrentLoaded_BlockTensor_K_Iter := CurrentLoaded_BlockTensor_K_Iter + 1.U + when(k_wrap) { + CurrentLoaded_BlockTensor_K_Iter := 0.U + CurrentLoaded_BlockTensor_N_Iter := CurrentLoaded_BlockTensor_N_Iter + 1.U + } + } + }.otherwise { + Request_N_Iter_Time := Request_N_Iter_Time + 1.U + when(Request_N_Iter_Time === (Matrix_MN - 1).U || (CurrentLoaded_BlockTensor_N_Iter + Request_N_Iter_Time) === MatrixRegTensor_N - 1.U){ + Request_N_Iter_Time := 0.U + CurrentLoaded_BlockTensor_K_Iter := CurrentLoaded_BlockTensor_K_Iter + 1.U + when(CurrentLoaded_BlockTensor_K_Iter + 1.U === K_Beat_Count){ + CurrentLoaded_BlockTensor_K_Iter := 0.U + CurrentLoaded_BlockTensor_N_Iter := CurrentLoaded_BlockTensor_N_Iter + Matrix_MN.U + } } } + when(TotalRequestSize =/= MaxRequestIter){ + TotalRequestSize := TotalRequestSize + 1.U + } } val current_fill_fifo_full = WireInit(false.B) - when(io.LocalMMUIO.Response.valid) + when(io.LocalMMUIO.Response.valid && !Is_Transpose) { val sourceId = io.LocalMMUIO.Response.bits.ReseponseSourceID val MatrixRegBankId = SoureceIdSearchTable(sourceId).asTypeOf(new BSourceIdSearch).MatrixRegBankId @@ -275,43 +387,81 @@ class BMemoryLoader(implicit p: Parameters) extends CuteModule{ //接受访存的返回值 //一个cam来存储访存请求的source_id对应的MatrixReg的地址和bank号 //根据response的sourceid,找到对应的MatrixReg的Fill_Table的队伍头的索引,填充到Fill_Table中 - if (ABMLNeedMRegFillTable) - { - io.LocalMMUIO.Response.ready := MReg_Fill_Table_Not_Full && (current_fill_fifo_full === false.B) - } else - { - io.LocalMMUIO.Response.ready := true.B + val normalRespReady = if (ABMLNeedMRegFillTable) { + MReg_Fill_Table_Not_Full && (current_fill_fifo_full === false.B) + } else { + true.B } + io.LocalMMUIO.Response.ready := Mux(Is_Transpose, !transBusStall, normalRespReady) when(io.LocalMMUIO.Response.fire){ //Trick注意这个设计,是doublebuffer的,AB只能是doublebuffer,回数一定是不会堵的,而且我们有时间对数据进行压缩解压缩~ //如果要做release设计,要么数据位宽翻倍,腾出周期来使得有空泡能给写任务进行,要么就是数据位宽不变,将读写端口变成独立的读和独立的写端口 val sourceId = io.LocalMMUIO.Response.bits.ReseponseSourceID - val MatrixRegBankId = SoureceIdSearchTable(sourceId).asTypeOf(new BSourceIdSearch).MatrixRegBankId - val MatrixRegAddr = SoureceIdSearchTable(sourceId).asTypeOf(new BSourceIdSearch).MatrixRegAddr + val searchEntry = SoureceIdSearchTable(sourceId).asTypeOf(new BSourceIdSearch) + val MatrixRegBankId = searchEntry.MatrixRegBankId + val MatrixRegAddr = searchEntry.MatrixRegAddr val ResponseData = io.LocalMMUIO.Response.bits.ReseponseData val FIFOIndex = Bank_Fill_Search_FIFO_Head(MatrixRegBankId)//该bank的fill_fifo_index,标注了它当前在fillfifo的哪个位置,我们一共有bank个fill_fifo - if (!ABMLNeedMRegFillTable) - { - TotalLoadSize := TotalLoadSize + 1.U - for (i <- 0 until ABMatrixRegNBanks) + if (YJPBMLDebugEnable) { + printf("[BML_ResponseHandshake<%d>] ResponseData:%x, MatrixRegBankId:%d, MatrixRegAddr:%d, SourceId:%d, FIFOIndex:%d, Tail:%d\n",io.DebugInfo.DebugTimeStampe,ResponseData,MatrixRegBankId,MatrixRegAddr,sourceId,FIFOIndex,searchEntry.MatrixRegisTail) + } + + when(Is_Transpose) { + val next_group_resp_cnt = group_resp_cnt + 1.U + val drain_trigger = next_group_resp_cnt === active_group_size + + printf("[BML_TRANS_RESP<%d>] source:%d data:%x tail:%d mask:%x base:%d beatIndex:%d respCnt:%d nextResp:%d activeGroupSize:%d drainTrig:%d writeBase:%d writeCnt:%d\n", + io.DebugInfo.DebugTimeStampe, sourceId, ResponseData, searchEntry.MatrixRegisTail.asUInt, + Mux(searchEntry.MatrixRegisTail, tailTaskMask, fullTaskMask), MatrixRegAddr, + searchEntry.BeatIndex, group_resp_cnt, next_group_resp_cnt, active_group_size, + drain_trigger.asUInt, transWriteBaseAddr, transWriteAddrCnt) + + transPipeInValid := true.B + transPipeInData := ResponseData + transPipeInMask := Mux(searchEntry.MatrixRegisTail, tailTaskMask, fullTaskMask) + transPipeRespBeatCnt := group_resp_cnt + transPipeEntryOffset := searchEntry.BeatIndex + transPipeDrainTrigger := drain_trigger + + when(group_resp_cnt === 0.U) { + transWriteBaseAddr := MatrixRegAddr + transWriteAddrCnt := 0.U + } + + when(next_group_resp_cnt === active_group_size) { + group_req_cnt := 0.U + group_resp_cnt := 0.U + group_size_reg := 0.U + }.otherwise { + group_resp_cnt := next_group_resp_cnt + } + }.otherwise { + if (!ABMLNeedMRegFillTable) { - when(MatrixRegBankId === i.U) + TotalLoadSize := TotalLoadSize + 1.U + for (i <- 0 until ABMatrixRegNBanks) { - io.ToMatrixRegIO.BankAddr(i).bits := MatrixRegAddr - io.ToMatrixRegIO.Data(i).bits := ResponseData - io.ToMatrixRegIO.BankAddr(i).valid := true.B - io.ToMatrixRegIO.Data(i).valid := true.B + when(MatrixRegBankId === i.U) + { + io.ToMatrixRegIO.BankAddr(i).bits := MatrixRegAddr + io.ToMatrixRegIO.Data(i).bits := ResponseData(255, 0) + io.ToMatrixRegIO.BankAddr(i).valid := true.B + io.ToMatrixRegIO.Data(i).valid := true.B + io.ToMatrixRegIO.ByteMask(i).bits := Mux(searchEntry.MatrixRegisTail, tailTaskMask(31, 0), Fill(ABMatrixRegEntryByteSize, true.B)) + io.ToMatrixRegIO.ByteMask(i).valid := true.B + } } } - } - MReg_Fill_Table(MReg_Fill_Table_Insert_Index) := ResponseData - MReg_Fill_Table_MReg_Addr(MReg_Fill_Table_Insert_Index) := MatrixRegAddr - MReg_Fill_Table_Time(MReg_Fill_Table_Insert_Index) := MAX_Fill_Times.U + MReg_Fill_Table(MReg_Fill_Table_Insert_Index) := ResponseData + MReg_Fill_Table_MReg_Addr(MReg_Fill_Table_Insert_Index) := MatrixRegAddr + MReg_Fill_Table_Time(MReg_Fill_Table_Insert_Index) := MAX_Fill_Times.U + MReg_Fill_Table_IsTail(MReg_Fill_Table_Insert_Index) := searchEntry.MatrixRegisTail - Bank_Fill_Search_FIFO(MatrixRegBankId)(FIFOIndex) := MReg_Fill_Table_Insert_Index - Bank_Fill_Search_FIFO_Head(MatrixRegBankId) := WrapInc(Bank_Fill_Search_FIFO_Head(MatrixRegBankId), BMemoryLoaderReadFromMemoryFIFODepth) + Bank_Fill_Search_FIFO(MatrixRegBankId)(FIFOIndex) := MReg_Fill_Table_Insert_Index + Bank_Fill_Search_FIFO_Head(MatrixRegBankId) := WrapInc(Bank_Fill_Search_FIFO_Head(MatrixRegBankId), BMemoryLoaderReadFromMemoryFIFODepth) + } //需要一个fifo?TODO:需要fifo的设计是可能这里会堵,实际上我们满吞吐的doublebuff的设计,咱们这里是不会堵的,直接填就完事了?还是等总线上去握手? //MatrixReg->MemoryLoader->MMU->Memory Bus->Memory上的长组合逻辑链,可以实现一下,为后续的开发做准备 //否则就靠软件来保证数据流和访存流,保证访存流的稳定性,一定不会堵,就可以省下这个长组合逻辑的延迟? @@ -327,61 +477,152 @@ class BMemoryLoader(implicit p: Parameters) extends CuteModule{ } } - // Fill_Table的回填优先级最高,一旦有回填任务就立即执行 - val HasScarhpadWrite = Have_Bank_Fill - val Current_Fill_MReg_Time = WireInit(VecInit(Seq.fill(ABMatrixRegNBanks)(0.U(1.W)))) - if (ABMLNeedMRegFillTable) - { - for (i <- 0 until ABMatrixRegNBanks){ - when(Bank_Fill_Search_FIFO_Empty(i) === false.B){ - val CurrentFIFOIndex = Bank_Fill_Search_FIFO(i)(Bank_Fill_Search_FIFO_Tail(i)) - Current_Fill_MReg_Time(i) := 1.U - val MatrixRegWriteRequest = io.ToMatrixRegIO - val FIFOData = WireInit((VecInit(Seq.fill(MAX_Fill_Times)(0.U((8*ABMatrixRegEntryByteSize).W))))) - FIFOData := MReg_Fill_Table(CurrentFIFOIndex).asTypeOf(FIFOData) - MatrixRegWriteRequest.BankAddr(i).bits := MReg_Fill_Table_MReg_Addr(CurrentFIFOIndex) + (MAX_Fill_Times.U - MReg_Fill_Table_Time(CurrentFIFOIndex)) - MatrixRegWriteRequest.BankAddr(i).valid := true.B - MatrixRegWriteRequest.Data(i).bits := FIFOData(MAX_Fill_Times.U - MReg_Fill_Table_Time(CurrentFIFOIndex)) - MatrixRegWriteRequest.Data(i).valid := true.B - - MReg_Fill_Table_Time(CurrentFIFOIndex) := MReg_Fill_Table_Time(CurrentFIFOIndex) - 1.U - when(MReg_Fill_Table_Time(CurrentFIFOIndex) === 1.U){ - Bank_Fill_Search_FIFO_Tail(i) := WrapInc(Bank_Fill_Search_FIFO_Tail(i), BMemoryLoaderReadFromMemoryFIFODepth) + when(Is_Transpose) { + for (i <- 0 until ABMatrixRegNBanks) { + val routerValid = transRouters(i).io.valid + io.ToMatrixRegIO.BankAddr(i).bits := transWriteAddr + io.ToMatrixRegIO.BankAddr(i).valid := routerValid + io.ToMatrixRegIO.Data(i).bits := transRouters(i).io.final_data + io.ToMatrixRegIO.Data(i).valid := routerValid + io.ToMatrixRegIO.ByteMask(i).bits := transRouters(i).io.final_mask + io.ToMatrixRegIO.ByteMask(i).valid := routerValid + if (YJPBMLDebugEnable) { + when(routerValid) { + printf("[BML_TransposeWrite<%d>] bank:%d, Addr:%d, Data:%x, Mask:%x\n", + io.DebugInfo.DebugTimeStampe, i.U, io.ToMatrixRegIO.BankAddr(i).bits, + transRouters(i).io.final_data, transRouters(i).io.final_mask) } - - if (YJPCMLDebugEnable) - { - //输出fill_time 和 fifoindex - printf("[BML BMemoryLoader_Load<%d>]bankid: %d,CurrentFIFOIndex %d,ScartchPadAddr: %x, MReg_Fill_Table_Time(CurrentFIFOIndex): %d\n", io.DebugInfo.DebugTimeStampe,i.U, CurrentFIFOIndex, MReg_Fill_Table_MReg_Addr(CurrentFIFOIndex), MReg_Fill_Table_Time(CurrentFIFOIndex)) - printf("[BML BMemoryLoader_Load<%d>]bankid: %d,ScartchPadAddr: %x, BankAddr: %x, Data: %x\n", io.DebugInfo.DebugTimeStampe,i.U, MReg_Fill_Table_MReg_Addr(CurrentFIFOIndex), MatrixRegWriteRequest.BankAddr(i).bits, MatrixRegWriteRequest.Data(i).bits) + } + } + when(transRouterWriteValid) { + printf("[BML_TRANS_WRITE<%d>] validVec:%b base:%d cnt:%d addr:%d bank0Data:%x bank0Mask:%x totalLoad:%d pipelineEmpty:%d\n", + io.DebugInfo.DebugTimeStampe, transRouterValidVec, transWriteBaseAddr, transWriteAddrCnt, + transWriteAddr, transRouters(0).io.final_data, transRouters(0).io.final_mask, + TotalLoadSize, transPipelineEmpty.asUInt) + transWriteAddrCnt := Mux( + transWriteAddrCnt === (Trans_Load_Size - 1).U, + 0.U, + transWriteAddrCnt + 1.U + ) + } + val Current_Load_Fill_Size = transRouterWriteValid.asUInt + val nextTotalLoadSize = TotalLoadSize + Current_Load_Fill_Size + val transposeDone = TotalRequestSize === MaxRequestIter && + group_req_cnt === 0.U && group_resp_cnt === 0.U && + transPipelineEmpty + TotalLoadSize := nextTotalLoadSize + if(YJPBMLDebugEnable){ + when(Current_Load_Fill_Size =/= 0.U) { + printf("[BML_TransposeLoad<%d>]TotalLoadSize:%d, FillSize:%d\n", io.DebugInfo.DebugTimeStampe, TotalLoadSize, Current_Load_Fill_Size) + } + } + when(transposeDone){ + memoryload_state := s_load_quiesce + transposeEndDrainCnt := (transposeEndDrainCycles - 1).U + if (YJPBMLDebugEnable) printf("[BML<%d>]TransposeFullLoadEnd\n", io.DebugInfo.DebugTimeStampe) + } + }.otherwise { + // Fill_Table的回填优先级最高,一旦有回填任务就立即执行 + val HasScarhpadWrite = Have_Bank_Fill + val Current_Fill_MReg_Time = WireInit(VecInit(Seq.fill(ABMatrixRegNBanks)(0.U(1.W)))) + if (ABMLNeedMRegFillTable) + { + for (i <- 0 until ABMatrixRegNBanks){ + when(Bank_Fill_Search_FIFO_Empty(i) === false.B){ + val CurrentFIFOIndex = Bank_Fill_Search_FIFO(i)(Bank_Fill_Search_FIFO_Tail(i)) + val fillSlot = MAX_Fill_Times.U - MReg_Fill_Table_Time(CurrentFIFOIndex) + val fillLowHalf = fillSlot(0) === 0.U + val fillSlotOH = UIntToOH(fillSlot, MAX_Fill_Times) + val currentIsTail = MReg_Fill_Table_IsTail(CurrentFIFOIndex) + Current_Fill_MReg_Time(i) := 1.U + val MatrixRegWriteRequest = io.ToMatrixRegIO + val FIFOData = WireInit((VecInit(Seq.fill(MAX_Fill_Times)(0.U((8*ABMatrixRegEntryByteSize).W))))) + FIFOData := MReg_Fill_Table(CurrentFIFOIndex).asTypeOf(FIFOData) + MatrixRegWriteRequest.BankAddr(i).bits := MReg_Fill_Table_MReg_Addr(CurrentFIFOIndex) + fillSlot + MatrixRegWriteRequest.BankAddr(i).valid := true.B + MatrixRegWriteRequest.Data(i).bits := Mux(fillLowHalf, FIFOData(0), FIFOData(1)) + MatrixRegWriteRequest.Data(i).valid := true.B + MatrixRegWriteRequest.ByteMask(i).bits := Mux(currentIsTail && fillSlotOH(1), tailTaskMask(63, 32), Mux(currentIsTail && fillSlotOH(0), tailTaskMask(31, 0), Fill(ABMatrixRegEntryByteSize, true.B))) + MatrixRegWriteRequest.ByteMask(i).valid := true.B + if (YJPBMLDebugEnable) { + printf("[BML_MRegWriteHandshake<%d>] bankid: %d, CurrentFIFOIndex: %d, ScartchPadAddr: %x, BankAddr: %x, Data: %x, ByteMask: %x\n", io.DebugInfo.DebugTimeStampe,i.U, CurrentFIFOIndex, MReg_Fill_Table_MReg_Addr(CurrentFIFOIndex), MatrixRegWriteRequest.BankAddr(i).bits, MatrixRegWriteRequest.Data(i).bits, MatrixRegWriteRequest.ByteMask(i).bits) + } + + MReg_Fill_Table_Time(CurrentFIFOIndex) := MReg_Fill_Table_Time(CurrentFIFOIndex) - 1.U + when(MReg_Fill_Table_Time(CurrentFIFOIndex) === 1.U){ + Bank_Fill_Search_FIFO_Tail(i) := WrapInc(Bank_Fill_Search_FIFO_Tail(i), BMemoryLoaderReadFromMemoryFIFODepth) + } + + if (YJPBMLDebugEnable) + { + //输出fill_time 和 fifoindex + printf("[BML BMemoryLoader_Load<%d>]bankid: %d,CurrentFIFOIndex %d,ScartchPadAddr: %x, MReg_Fill_Table_Time(CurrentFIFOIndex): %d\n", io.DebugInfo.DebugTimeStampe,i.U, CurrentFIFOIndex, MReg_Fill_Table_MReg_Addr(CurrentFIFOIndex), MReg_Fill_Table_Time(CurrentFIFOIndex)) + printf("[BML BMemoryLoader_Load<%d>]bankid: %d,ScartchPadAddr: %x, BankAddr: %x, Data: %x\n", io.DebugInfo.DebugTimeStampe,i.U, MReg_Fill_Table_MReg_Addr(CurrentFIFOIndex), MatrixRegWriteRequest.BankAddr(i).bits, MatrixRegWriteRequest.Data(i).bits) + } } } } - } - val Current_Load_Fill_Size = WireInit(0.U((log2Ceil(ABMatrixRegNBanks)+1).W)) - Current_Load_Fill_Size := PopCount(Current_Fill_MReg_Time.asUInt) + val Current_Load_Fill_Size = WireInit(0.U((log2Ceil(ABMatrixRegNBanks)+1).W)) + Current_Load_Fill_Size := PopCount(Current_Fill_MReg_Time.asUInt) - if (ABMLNeedMRegFillTable) - { - TotalLoadSize := TotalLoadSize + Current_Load_Fill_Size - } - if (YJPCMLDebugEnable) - { - when(Current_Load_Fill_Size =/= 0.U) + if (ABMLNeedMRegFillTable) { - printf("[BMemoryLoader_Load<%d>]Current_Load_Fill_Size: %d, TotalLoadSize: %d, MaxLoadSize: %d\n",io.DebugInfo.DebugTimeStampe, Current_Load_Fill_Size, TotalLoadSize, MaxRequestIter * MAX_Fill_Times.U) + TotalLoadSize := TotalLoadSize + Current_Load_Fill_Size } - } - //状态机切换 - when(TotalLoadSize === (MaxRequestIter * MAX_Fill_Times.U)){ - memoryload_state := s_load_end - if (YJPCMLDebugEnable) + if (YJPBMLDebugEnable) { - printf("[BMemoryLoader_Load<%d>]LoadEnd\n",io.DebugInfo.DebugTimeStampe) + when(Current_Load_Fill_Size =/= 0.U) + { + printf("[BMemoryLoader_Load<%d>]Current_Load_Fill_Size: %d, TotalLoadSize: %d, MaxLoadSize: %d\n",io.DebugInfo.DebugTimeStampe, Current_Load_Fill_Size, TotalLoadSize, MaxRequestIter * MAX_Fill_Times.U) + } + } + //状态机切换 + when(TotalLoadSize === (MaxRequestIter * MAX_Fill_Times.U)){ + memoryload_state := s_load_end + if (YJPBMLDebugEnable) + { + printf("[BMemoryLoader_Load<%d>]LoadEnd\n",io.DebugInfo.DebugTimeStampe) + } } } } + is(s_load_quiesce) { + io.ToMatrixRegIO.active := true.B + for (i <- 0 until ABMatrixRegNBanks) { + val routerValid = transRouters(i).io.valid + io.ToMatrixRegIO.BankAddr(i).bits := transWriteAddr + io.ToMatrixRegIO.BankAddr(i).valid := routerValid + io.ToMatrixRegIO.Data(i).bits := transRouters(i).io.final_data + io.ToMatrixRegIO.Data(i).valid := routerValid + io.ToMatrixRegIO.ByteMask(i).bits := transRouters(i).io.final_mask + io.ToMatrixRegIO.ByteMask(i).valid := routerValid + if (YJPBMLDebugEnable) { + when(routerValid) { + printf("[BML_TransposeQuiesceWrite<%d>] bank:%d, Addr:%d, Data:%x, Mask:%x\n", + io.DebugInfo.DebugTimeStampe, i.U, io.ToMatrixRegIO.BankAddr(i).bits, + transRouters(i).io.final_data, transRouters(i).io.final_mask) + } + } + } + when(transRouterWriteValid) { + printf("[BML_TRANS_QUIESCE_WRITE<%d>] validVec:%b base:%d cnt:%d addr:%d bank0Data:%x bank0Mask:%x drainCnt:%d pipelineEmpty:%d\n", + io.DebugInfo.DebugTimeStampe, transRouterValidVec, transWriteBaseAddr, transWriteAddrCnt, + transWriteAddr, transRouters(0).io.final_data, transRouters(0).io.final_mask, + transposeEndDrainCnt, transPipelineEmpty.asUInt) + transWriteAddrCnt := Mux( + transWriteAddrCnt === (Trans_Load_Size - 1).U, + 0.U, + transWriteAddrCnt + 1.U + ) + } + when(transposeEndDrainCnt === 0.U) { + memoryload_state := s_load_end + if (YJPBMLDebugEnable) printf("[BML<%d>]TransposeQuiesceEnd\n", io.DebugInfo.DebugTimeStampe) + }.otherwise { + transposeEndDrainCnt := transposeEndDrainCnt - 1.U + } + } is(s_load_end) { io.ConfigInfo.MicroTaskEndValid := true.B when(io.ConfigInfo.MicroTaskEndValid && io.ConfigInfo.MicroTaskEndReady){ @@ -394,4 +635,4 @@ class BMemoryLoader(implicit p: Parameters) extends CuteModule{ } } } -} \ No newline at end of file +} diff --git a/src/main/scala/Bundles.scala b/src/main/scala/Bundles.scala index d0cb184..1f71137 100644 --- a/src/main/scala/Bundles.scala +++ b/src/main/scala/Bundles.scala @@ -93,7 +93,7 @@ object Bundles { val mtilem = Mtilex() // 36 : 28 val mtilen = Mtilex() // 27 : 19 val mtilek = Mtilex() // 18 : 10 - + // the type of source matrices // - lower 2 bits stands for the element width: // - 0: e8, 1: e16, 2: e32, 3: e4 @@ -134,7 +134,7 @@ object Bundles { val baseAddr = UInt(48.W) // 116 : 69 // the stride of the matrix val stride = UInt(48.W) // 68 : 21 - + // the number of rows of the matrix val row = Mtilex() // 20 : 12 // the number of columns of the matrix diff --git a/src/main/scala/CDataController.scala b/src/main/scala/CDataController.scala index 2111118..756c251 100644 --- a/src/main/scala/CDataController.scala +++ b/src/main/scala/CDataController.scala @@ -71,6 +71,7 @@ class CDataController(implicit p: Parameters) extends CuteModule{ for (i <- 0 until CMatrixRegNBanks) { difftestAmuFinish.bankValid(i) := io.FromMatrixRegIO.WriteBankAddr(i).valid difftestAmuFinish.bankAddr(i) := io.FromMatrixRegIO.WriteBankAddr(i).bits + difftestAmuFinish.bankMask(i) := Fill(CMatrixRegEntryByteSize, true.B) for (w <- 0 until eventWordsPerBank) { if (w < cMRegWordsPerBank) { val lo = w * 64 diff --git a/src/main/scala/CMatrixReg.scala b/src/main/scala/CMatrixReg.scala index e72af73..affe2f3 100644 --- a/src/main/scala/CMatrixReg.scala +++ b/src/main/scala/CMatrixReg.scala @@ -62,6 +62,7 @@ class CMatrixReg(scp_id:Int)(implicit p: Parameters) extends CuteModule{ val write_request_per_bank_addr = WireInit(VecInit(Seq.fill(CMatrixRegNBanks)(0.U(CMatrixRegBankNEntries.W)))) val write_request_per_bank_data= WireInit(VecInit(Seq.fill(CMatrixRegNBanks)(0.U((8*CMatrixRegEntryByteSize).W)))) + val write_request_per_bank_mask = WireInit(VecInit(Seq.fill(CMatrixRegNBanks)(Fill(CMatrixRegEntryByteSize, true.B)))) val write_request_per_bank_valid = WireInit(VecInit(Seq.fill(CMatrixRegNBanks)(false.B))) for( i <- 0 until CMatrixRegNBanks) { @@ -71,6 +72,7 @@ class CMatrixReg(scp_id:Int)(implicit p: Parameters) extends CuteModule{ write_request_per_bank_addr(i) := Mux(decode_request.IsWriteFromDataController, io.MatrixRegIO.FromDataController.WriteBankAddr(i).bits, io.MatrixRegIO.FromMemoryLoader.WriteRequestToMatrixReg.BankAddr(i).bits) write_request_per_bank_data(i) := Mux(decode_request.IsWriteFromDataController, io.MatrixRegIO.FromDataController.WriteRequestData(i).bits, io.MatrixRegIO.FromMemoryLoader.WriteRequestToMatrixReg.Data(i).bits) + write_request_per_bank_mask(i) := Mux(decode_request.IsWriteFromDataController, Fill(CMatrixRegEntryByteSize, true.B), io.MatrixRegIO.FromMemoryLoader.WriteRequestToMatrixReg.ByteMask(i).bits) write_request_per_bank_valid(i) := Mux(decode_request.IsWriteFromDataController, io.MatrixRegIO.FromDataController.WriteRequestData(i).valid, io.MatrixRegIO.FromMemoryLoader.WriteRequestToMatrixReg.Data(i).valid) } @@ -81,9 +83,9 @@ class CMatrixReg(scp_id:Int)(implicit p: Parameters) extends CuteModule{ // 两个单口SRAM,奇偶地址各自负责,期望奇偶地址读写错开,奇读偶写,偶读奇写 val bankDepthHalf = (CMatrixRegBankNEntries + 1) / 2 val evenBank = Module(new SRAMTemplate( - gen = UInt((8*CMatrixRegEntryByteSize).W), + gen = UInt(8.W), set = bankDepthHalf, - way = 1, + way = CMatrixRegEntryByteSize, singlePort = true, latency = 1, hasMbist = false, @@ -92,9 +94,9 @@ class CMatrixReg(scp_id:Int)(implicit p: Parameters) extends CuteModule{ evenBank.suggestName("CUTE-C-MatrixReg-Even-SRAM") val oddBank = Module(new SRAMTemplate( - gen = UInt((8*CMatrixRegEntryByteSize).W), + gen = UInt(8.W), set = bankDepthHalf, - way = 1, + way = CMatrixRegEntryByteSize, singlePort = true, latency = 1, hasMbist = false, @@ -135,20 +137,22 @@ class CMatrixReg(scp_id:Int)(implicit p: Parameters) extends CuteModule{ oddBank.io.r.req.bits.setIdx := s0_read_idx // 读数据在下一拍返回(latency=1) - val even_read_data = evenBank.io.r.resp.data(0) - val odd_read_data = oddBank.io.r.resp.data(0) + val even_read_data = evenBank.io.r.resp.data.asUInt + val odd_read_data = oddBank.io.r.resp.data.asUInt s1_bank_read_data := Mux(s1_read_is_even, even_read_data, odd_read_data) //单口写路径:奇偶分流 // 偶地址SRAM写请求 evenBank.io.w.req.valid := s0_bank_write_valid && s0_write_is_even evenBank.io.w.req.bits.setIdx := s0_write_idx - evenBank.io.w.req.bits.data(0) := s0_bank_write_data + evenBank.io.w.req.bits.waymask.get := write_request_per_bank_mask(i) + evenBank.io.w.req.bits.data := s0_bank_write_data.asTypeOf(Vec(CMatrixRegEntryByteSize, UInt(8.W))) // 奇地址SRAM写请求 oddBank.io.w.req.valid := s0_bank_write_valid && !s0_write_is_even oddBank.io.w.req.bits.setIdx := s0_write_idx - oddBank.io.w.req.bits.data(0) := s0_bank_write_data + oddBank.io.w.req.bits.waymask.get := write_request_per_bank_mask(i) + oddBank.io.w.req.bits.data := s0_bank_write_data.asTypeOf(Vec(CMatrixRegEntryByteSize, UInt(8.W))) //单口SRAM读写不能同拍同时有效,分别对奇偶SRAM进行断言 val even_conflict = s0_bank_read_valid && s0_read_is_even && s0_bank_write_valid && s0_write_is_even diff --git a/src/main/scala/CMemoryLoader.scala b/src/main/scala/CMemoryLoader.scala index ce4c4fb..955bf87 100644 --- a/src/main/scala/CMemoryLoader.scala +++ b/src/main/scala/CMemoryLoader.scala @@ -19,6 +19,7 @@ import freechips.rocketchip.util.SeqToAugmentedSeq class CSourceIdSearch(implicit p: Parameters) extends CuteBundle{ val MatrixRegBankId =UInt(log2Ceil(CMatrixRegNBanks).W) val MatrixRegAddr = UInt(log2Ceil(CMatrixRegBankNEntries).W) + val MatrixRegisTail = Bool() } class CMemoryLoader(implicit p: Parameters) extends CuteModule{ @@ -46,10 +47,14 @@ class CMemoryLoader(implicit p: Parameters) extends CuteModule{ io.ToMatrixRegIO.StoreReadWriteRequest := 0.U io.LoadLocalMMUIO.Request.valid := false.B io.LoadLocalMMUIO.Request.bits := DontCare + io.LoadLocalMMUIO.Request.bits.RequestMask := Fill(MMUMaskWidth, 1.U(1.W)) io.LoadLocalMMUIO.Response.ready := false.B io.StoreLocalMMUIO.Request.valid := false.B io.StoreLocalMMUIO.Request.bits := DontCare + io.StoreLocalMMUIO.Request.bits.RequestMask := Fill(MMUMaskWidth, 1.U(1.W)) io.StoreLocalMMUIO.Response.ready := false.B + io.ToMatrixRegIO.WriteRequestToMatrixReg.ByteMask.map(_.valid := false.B) + io.ToMatrixRegIO.WriteRequestToMatrixReg.ByteMask.map(_.bits := Fill(CMatrixRegEntryByteSize, true.B)) val CurrentLoadMatrixRegId = RegInit(0.U(CMatrixRegIdWidth.W)) val CurrentStoreMatrixRegId = RegInit(0.U(CMatrixRegIdWidth.W)) @@ -82,6 +87,7 @@ class CMemoryLoader(implicit p: Parameters) extends CuteModule{ for (i <- 0 until CMatrixRegNBanks) { difftestLoadFinish.bankValid(i) := io.ToMatrixRegIO.WriteRequestToMatrixReg.BankAddr(i).valid difftestLoadFinish.bankAddr(i) := io.ToMatrixRegIO.WriteRequestToMatrixReg.BankAddr(i).bits + difftestLoadFinish.bankMask(i) := io.ToMatrixRegIO.WriteRequestToMatrixReg.ByteMask(i).bits for (w <- 0 until eventWordsPerBank) { if (w < cMRegWordsPerBank) { val lo = w * 64 @@ -102,7 +108,6 @@ class CMemoryLoader(implicit p: Parameters) extends CuteModule{ difftestStoreFinish.pc := StorePcReg.get difftestStoreFinish.bankValid.foreach(_ := false.B) difftestStoreFinish.bankAddr.foreach(_ := 0.U) - difftestStoreFinish.data.foreach(_ := 0.U) difftestStoreFinish.finish := storeFinishAny } @@ -133,6 +138,9 @@ class CMemoryLoader(implicit p: Parameters) extends CuteModule{ val ApplicationTensor_C_Stride_M = RegInit(0.U(MMUAddrWidth.W)) val ApplicationTensor_D_Stride_M = RegInit(0.U(MMUAddrWidth.W)) + val HasTail = RegInit(false.B) + val TailByteMask = RegInit(0.U(log2Ceil(outsideDataWidthByte + 1).W)) + val N_Beat_Count = RegInit(0.U(MatrixRegMaxTensorDimBitSize.W)) val Is_ZeroLoad = RegInit(false.B) val Is_FullLoad = RegInit(false.B) @@ -152,6 +160,9 @@ class CMemoryLoader(implicit p: Parameters) extends CuteModule{ IsLoadConherent := io.ConfigInfo.Conherent LoadMatrixRegTensor_M := io.ConfigInfo.MatrixRegTensor_M LoadMatrixRegTensor_N := io.ConfigInfo.MatrixRegTensor_N + HasTail := io.ConfigInfo.ApplicationTensor_C.HasTail + TailByteMask := io.ConfigInfo.ApplicationTensor_C.TailByteMask + N_Beat_Count := io.ConfigInfo.ApplicationTensor_C.N_Beat_Count Is_ZeroLoad := io.ConfigInfo.LoadTaskInfo.Is_ZeroLoad Is_FullLoad := io.ConfigInfo.LoadTaskInfo.Is_FullLoad @@ -225,6 +236,7 @@ class CMemoryLoader(implicit p: Parameters) extends CuteModule{ val MReg_Fill_Table = RegInit((VecInit(Seq.fill(CMemoryLoaderReadFromMemoryFIFODepth)(0.U(outsideDataWidth.W))))) val MReg_Fill_Table_MReg_Addr = RegInit((VecInit(Seq.fill(CMemoryLoaderReadFromMemoryFIFODepth)(0.U(log2Ceil(CMatrixRegBankNEntries).W)))))//记录这个LLC回的数是在scp的哪个地址 val MReg_Fill_Table_Time = RegInit((VecInit(Seq.fill(CMemoryLoaderReadFromMemoryFIFODepth)(0.U((log2Ceil(outsideDataWidthByte/CMatrixRegEntryByteSize)+1).W)))))//记录这个LLC回的数需要回填的次数,完成就可以将数据释放了 + val MReg_Fill_Table_IsTail = RegInit(VecInit(Seq.fill(CMemoryLoaderReadFromMemoryFIFODepth)(false.B))) val MReg_Fill_Table_Free = MReg_Fill_Table_Time.map(_ === 0.U)//记录这个FIFO能否能填数据 val MReg_Fill_Table_Valid = MReg_Fill_Table_Time.map(_ =/= 0.U)//记录这个FIFO里的数据是否有效 val MReg_Fill_Table_Insert_Index = PriorityEncoder(MReg_Fill_Table_Free)//返回第一个空位的index @@ -265,13 +277,14 @@ class CMemoryLoader(implicit p: Parameters) extends CuteModule{ TotalRequestSize := 0.U CurrentLoaded_BlockTensor_M_Iter := 0.U CurrentLoaded_BlockTensor_N_Iter := 0.U - MaxRequestIter := LoadMatrixRegTensor_M * LoadMatrixRegTensor_N * ResultWidthByte.U / (outsideDataWidthByte.U) //总共要发出的访存请求的次数 + MaxRequestIter := LoadMatrixRegTensor_M * N_Beat_Count //总共要发出的访存请求的次数 Bank_Fill_Search_FIFO := 0.U.asTypeOf(Bank_Fill_Search_FIFO) Bank_Fill_Search_FIFO_Head := 0.U.asTypeOf(Bank_Fill_Search_FIFO_Head) Bank_Fill_Search_FIFO_Tail := 0.U.asTypeOf(Bank_Fill_Search_FIFO_Tail) MReg_Fill_Table := 0.U.asTypeOf(MReg_Fill_Table) MReg_Fill_Table_MReg_Addr := 0.U.asTypeOf(MReg_Fill_Table_MReg_Addr) MReg_Fill_Table_Time := 0.U.asTypeOf(MReg_Fill_Table_Time) + MReg_Fill_Table_IsTail := VecInit(Seq.fill(CMemoryLoaderReadFromMemoryFIFODepth)(false.B)) Request_M_Iter_Time := 0.U MReg_Fill_Table_Head := 0.U MReg_Fill_Table_Tail := 0.U @@ -322,10 +335,12 @@ class CMemoryLoader(implicit p: Parameters) extends CuteModule{ when(Is_FullLoad) { + val tailTaskMask = UIntToOH(TailByteMask, outsideDataWidthByte + 1).asUInt - 1.U(outsideDataWidthByte.W) + val RequestBeatIsTail = HasTail && (CurrentLoaded_BlockTensor_N_Iter === (N_Beat_Count - 1.U)) val RequestMatrixRegBankId = (CurrentLoaded_BlockTensor_M_Iter + Request_M_Iter_Time) % CMatrixRegNBanks.U //访存请求落在哪个MatrixRegBank上 - val RequestMatrixRegAddr = (CurrentLoaded_BlockTensor_M_Iter / CMatrixRegNBanks.U * LoadMatrixRegTensor_N / Matrix_MN.U) + (CurrentLoaded_BlockTensor_N_Iter / Matrix_MN.U) //该访存请求的第零号数据,落在哪个MatrixRegBank的哪个地址上 + val RequestMatrixRegAddr = ((CurrentLoaded_BlockTensor_M_Iter + Request_M_Iter_Time) / CMatrixRegNBanks.U) * (Tensor_MN.U / Matrix_MN.U) + (CurrentLoaded_BlockTensor_N_Iter << log2Ceil(MAX_Fill_Times)) //该访存请求的第零号数据,落在哪个MatrixRegBank的哪个地址上 - ReadRequest.bits.RequestVirtualAddr := LoadTensorBlockBaseAddr + (CurrentLoaded_BlockTensor_M_Iter + Request_M_Iter_Time) * ApplicationTensor_C_Stride_M + CurrentLoaded_BlockTensor_N_Iter * C_DataWidth + ReadRequest.bits.RequestVirtualAddr := LoadTensorBlockBaseAddr + (CurrentLoaded_BlockTensor_M_Iter + Request_M_Iter_Time) * ApplicationTensor_C_Stride_M + (CurrentLoaded_BlockTensor_N_Iter << log2Ceil(outsideDataWidthByte)) // val CurrentBankID = RequestMatrixRegBankId // val CurrentFIFOIndex = FromMemoryLoaderReadFIFOHead @@ -336,6 +351,7 @@ class CMemoryLoader(implicit p: Parameters) extends CuteModule{ ReadRequest.bits.RequestConherent := IsLoadConherent ReadRequest.bits.RequestSourceID := sourceId.bits ReadRequest.bits.RequestType_isWrite := false.B + ReadRequest.bits.RequestMask := Fill(MMUMaskWidth, 1.U(1.W)) ReadRequest.valid := (TotalRequestSize < MaxRequestIter) //确定这个访存请求一定会发出 @@ -343,13 +359,14 @@ class CMemoryLoader(implicit p: Parameters) extends CuteModule{ val TableItem = Wire(new CSourceIdSearch) TableItem.MatrixRegBankId := RequestMatrixRegBankId TableItem.MatrixRegAddr := RequestMatrixRegAddr + TableItem.MatrixRegisTail := RequestBeatIsTail SoureceIdSearchTable(sourceId.bits) := TableItem.asUInt Request_M_Iter_Time := Request_M_Iter_Time + 1.U//连续的跨bank去访存 when(Request_M_Iter_Time === (Matrix_MN - 1).U || (CurrentLoaded_BlockTensor_M_Iter + Request_M_Iter_Time) === LoadMatrixRegTensor_M - 1.U){ Request_M_Iter_Time := 0.U - CurrentLoaded_BlockTensor_N_Iter := CurrentLoaded_BlockTensor_N_Iter + outsideDataWidthByte.U / C_DataWidth - when(CurrentLoaded_BlockTensor_N_Iter + outsideDataWidthByte.U / C_DataWidth === LoadMatrixRegTensor_N){ + CurrentLoaded_BlockTensor_N_Iter := CurrentLoaded_BlockTensor_N_Iter + 1.U + when(CurrentLoaded_BlockTensor_N_Iter + 1.U === N_Beat_Count){ CurrentLoaded_BlockTensor_N_Iter := 0.U CurrentLoaded_BlockTensor_M_Iter := CurrentLoaded_BlockTensor_M_Iter + Matrix_MN.U } @@ -397,6 +414,7 @@ class CMemoryLoader(implicit p: Parameters) extends CuteModule{ MReg_Fill_Table(MReg_Fill_Table_Insert_Index) := ResponseData MReg_Fill_Table_MReg_Addr(MReg_Fill_Table_Insert_Index) := MatrixRegAddr MReg_Fill_Table_Time(MReg_Fill_Table_Insert_Index) := MAX_Fill_Times.U + MReg_Fill_Table_IsTail(MReg_Fill_Table_Insert_Index) := SoureceIdSearchTable(sourceId).asTypeOf(new CSourceIdSearch).MatrixRegisTail Bank_Fill_Search_FIFO(MatrixRegBankId)(FIFOIndex) := MReg_Fill_Table_Insert_Index Bank_Fill_Search_FIFO_Head(MatrixRegBankId) := WrapInc(Bank_Fill_Search_FIFO_Head(MatrixRegBankId), CMemoryLoaderReadFromMemoryFIFODepth) @@ -419,11 +437,22 @@ class CMemoryLoader(implicit p: Parameters) extends CuteModule{ Current_Fill_MReg_Time(i) := 1.U val MatrixRegWriteRequest = io.ToMatrixRegIO.WriteRequestToMatrixReg val FIFOData = WireInit((VecInit(Seq.fill(MAX_Fill_Times)(0.U((8*CMatrixRegEntryByteSize).W))))) + val fillSlot = MAX_Fill_Times.U - MReg_Fill_Table_Time(CurrentFIFOIndex) + val currentIsTail = MReg_Fill_Table_IsTail(CurrentFIFOIndex) + val fullByteMask = Fill(CMatrixRegEntryByteSize, true.B) + val tailByteMaskVec = Wire(Vec(MAX_Fill_Times, UInt(CMatrixRegEntryByteSize.W))) + for (j <- 0 until MAX_Fill_Times) { + val high = (j + 1) * CMatrixRegEntryByteSize - 1 + val low = j * CMatrixRegEntryByteSize + tailByteMaskVec(j) := tailTaskMask(high, low) + } FIFOData := MReg_Fill_Table(CurrentFIFOIndex).asTypeOf(FIFOData) - MatrixRegWriteRequest.BankAddr(i).bits := MReg_Fill_Table_MReg_Addr(CurrentFIFOIndex) + (MAX_Fill_Times.U - MReg_Fill_Table_Time(CurrentFIFOIndex)) + MatrixRegWriteRequest.BankAddr(i).bits := MReg_Fill_Table_MReg_Addr(CurrentFIFOIndex) + fillSlot MatrixRegWriteRequest.BankAddr(i).valid := true.B - MatrixRegWriteRequest.Data(i).bits := FIFOData(MAX_Fill_Times.U - MReg_Fill_Table_Time(CurrentFIFOIndex)) + MatrixRegWriteRequest.Data(i).bits := FIFOData(fillSlot) MatrixRegWriteRequest.Data(i).valid := true.B + MatrixRegWriteRequest.ByteMask(i).bits := Mux(currentIsTail, tailByteMaskVec(fillSlot), fullByteMask) + MatrixRegWriteRequest.ByteMask(i).valid := true.B MReg_Fill_Table_Time(CurrentFIFOIndex) := MReg_Fill_Table_Time(CurrentFIFOIndex) - 1.U when(MReg_Fill_Table_Time(CurrentFIFOIndex) === 1.U){ @@ -860,6 +889,7 @@ class CMemoryLoader(implicit p: Parameters) extends CuteModule{ WriteRequest.bits.RequestConherent := IsStoreConherent WriteRequest.bits.RequestSourceID := io.StoreLocalMMUIO.ConherentRequsetSourceID.bits WriteRequest.bits.RequestType_isWrite := true.B + WriteRequest.bits.RequestMask := Fill(MMUMaskWidth, 1.U(1.W)) WriteRequest.bits.RequestData := Request_Data.asUInt WriteRequest.valid := true.B //只有fire了才能继续 diff --git a/src/main/scala/CUTE2YGJK.scala b/src/main/scala/CUTE2YGJK.scala index be36652..e2e241e 100644 --- a/src/main/scala/CUTE2YGJK.scala +++ b/src/main/scala/CUTE2YGJK.scala @@ -48,6 +48,7 @@ class CUTE2TLImp(outer: Cute2TL) extends LazyModuleImp(outer) with CUTEImplParam }) val data = io.mmu.Request.bits.RequestData + val mask = io.mmu.Request.bits.RequestMask val busy = RegInit(VecInit(Seq.fill(LLCSourceMaxNum)(false.B))) val id = WireInit(0.U(LLCSourceMaxNumBitSize.W)) @@ -137,7 +138,7 @@ class CUTE2TLImp(outer: Cute2TL) extends LazyModuleImp(outer) with CUTEImplParam tl_out.a.valid := io.mmu.Request.valid && !is_full tl_out.a.bits := Mux1H(Seq( (io.mmu.Request.bits.RequestType_isWrite === 0.U) -> edge.Get(id, io.mmu.Request.bits.RequestPhysicalAddr, log2Ceil(outsideDataWidthByte).U)._2, - (io.mmu.Request.bits.RequestType_isWrite === 1.U) -> edge.Put(id, io.mmu.Request.bits.RequestPhysicalAddr, log2Ceil(outsideDataWidthByte).U, data)._2 + (io.mmu.Request.bits.RequestType_isWrite === 1.U) -> edge.Put(id, io.mmu.Request.bits.RequestPhysicalAddr, log2Ceil(outsideDataWidthByte).U, data, mask)._2 )) // Assign MatrixKey to cooperate with HBL2. diff --git a/src/main/scala/CUTEParameters.scala b/src/main/scala/CUTEParameters.scala index 5cbb610..d299bfe 100644 --- a/src/main/scala/CUTEParameters.scala +++ b/src/main/scala/CUTEParameters.scala @@ -597,6 +597,7 @@ case class CuteParams( def DiffAmuFinishWordsPerBank = Math.max(ABMatrixRegEntryBitSize, CMatrixRegEntryBitSize) / 64 def ABMatrixRegNBanks = Matrix_MN //note that this is strongly tied to Matrix_MN and is generally an integer multiple of Matrix_MN def CMatrixRegNBanks = Matrix_MN //convenient for reorder + def Trans_Load_Size = outsideDataWidthByte / ABMatrixRegNBanks def ABMatrixReg_Total_Bandwidth = ABMatrixRegNBanks * ABMatrixRegEntryByteSize //total bandwidth of ABMatrixReg def CMatrixReg_Total_Bandwidth = CMatrixRegNBanks * CMatrixRegEntryByteSize //total bandwidth of CMatrixReg def ABMatrixReg_Total_Bandwidth_Bit = ABMatrixRegNBanks * ABMatrixRegEntryByteSize * 8 //total bandwidth of ABMatrixReg @@ -616,6 +617,7 @@ case class CuteParams( // require(ReduceGroupSize == 2, "ReduceGroupSize must be 2, Wait for update") require(outsideDataWidthByte <= Tensor_K, "outsideDataWidthByte must be less than or equal to Tensor_K, or a load will exceed the subtensor in micro load") + require(outsideDataWidthByte % ABMatrixRegNBanks == 0, "outsideDataWidthByte must be divisible by ABMatrixRegNBanks for transpose load") } @@ -715,6 +717,7 @@ trait CUTEImplParameters{ def DiffAmuFinishWordsPerBank = cuteParams.DiffAmuFinishWordsPerBank def ABMatrixRegNBanks = cuteParams.ABMatrixRegNBanks def CMatrixRegNBanks = cuteParams.CMatrixRegNBanks + def Trans_Load_Size = cuteParams.Trans_Load_Size def ABMatrixReg_Total_Bandwidth = cuteParams.ABMatrixReg_Total_Bandwidth def CMatrixReg_Total_Bandwidth = cuteParams.CMatrixReg_Total_Bandwidth def ABMatrixReg_Total_Bandwidth_Bit = cuteParams.ABMatrixReg_Total_Bandwidth_Bit @@ -946,6 +949,15 @@ class ApplicationTensor_A_Info()(implicit p: Parameters) extends CuteBundle{ // val BlockTensor_A_BaseVaddr = (UInt(MMUAddrWidth.W))//may be gone already val ApplicationTensor_A_Stride_M = (UInt(MMUAddrWidth.W))//address offset increment for the next M val dataType = (UInt(ElementDataType.DataTypeBitWidth.W)) + val HasTail = Bool() + val TailByteMask = UInt(log2Ceil(outsideDataWidthByte + 1).W) + val K_Beat_Count = UInt(MatrixRegMaxTensorDimBitSize.W) +} + +class ApplicationScale_A_Info()(implicit p: Parameters) extends CuteBundle{ + val ApplicationScale_A_BaseVaddr = (UInt(MMUAddrWidth.W)) + val BlockScale_A_BaseVaddr = (UInt(MMUAddrWidth.W)) // main active field + val computeType = (UInt(MteComputeType.ComputeTypeBitWidth.W)) } class ApplicationScale_A_Info()(implicit p: Parameters) extends CuteBundle{ @@ -965,6 +977,7 @@ class AMLMicroTaskConfigIO()(implicit p: Parameters) extends CuteBundle{ val MatrixRegId = UInt(ABMatrixRegIdWidth.W) val Conherent = (Bool()) //whether coherence is needed + val Is_Transpose = (Bool()) //whether transpose is needed val MicroTaskReady = Flipped(Bool())//can configure the next task val MicroTaskValid = (Bool()) //current task configuration is valid @@ -999,6 +1012,7 @@ class BMLMicroTaskConfigIO()(implicit p: Parameters) extends CuteBundle{ val MatrixRegId = UInt(ABMatrixRegIdWidth.W) val Conherent = (Bool()) //whether coherence is needed + val Is_Transpose = (Bool()) //whether transpose is needed val MicroTaskReady = Flipped(Bool())//can configure the next task val MicroTaskValid = (Bool()) //current task configuration is valid @@ -1029,19 +1043,20 @@ class ApplicationTensor_B_Info()(implicit p: Parameters) extends CuteBundle{ val BlockTensor_B_BaseVaddr = (UInt(MMUAddrWidth.W)) val ApplicationTensor_B_Stride_N = (UInt(MMUAddrWidth.W))//address offset increment for the next N val dataType = (UInt(ElementDataType.DataTypeBitWidth.W)) + val HasTail = Bool() + val TailByteMask = UInt(log2Ceil(outsideDataWidthByte + 1).W) + val K_Beat_Count = UInt(MatrixRegMaxTensorDimBitSize.W) } -class ApplicationScale_B_Info()(implicit p: Parameters) extends CuteBundle{ - val ApplicationScale_B_BaseVaddr = (UInt(MMUAddrWidth.W)) - val BlockScale_B_BaseVaddr = (UInt(MMUAddrWidth.W)) // main active field - val computeType = (UInt(MteComputeType.ComputeTypeBitWidth.W)) -} class ApplicationTensor_C_Info()(implicit p: Parameters) extends CuteBundle{ val ApplicationTensor_C_BaseVaddr = (UInt(MMUAddrWidth.W)) val BlockTensor_C_BaseVaddr = (UInt(MMUAddrWidth.W)) val ApplicationTensor_C_Stride_M = (UInt(MMUAddrWidth.W))//address offset increment for the next M val dataType = (UInt(ElementDataType.DataTypeBitWidth.W)) + val HasTail = Bool() + val TailByteMask = UInt(log2Ceil(outsideDataWidthByte + 1).W) + val N_Beat_Count = UInt(MatrixRegMaxTensorDimBitSize.W) } class ApplicationTensor_D_Info()(implicit p: Parameters) extends CuteBundle{ @@ -1133,6 +1148,16 @@ class ABMemoryLoaderMatrixRegIO(implicit p: Parameters) extends CuteBundle{ val BankAddr = Flipped(Vec(ABMatrixRegNBanks, Valid(UInt(log2Ceil(ABMatrixRegBankNEntries).W)))) //bankdata is the row data for each of the nbanks banks; it is a Vec with nbanks elements, each a UInt whose width is ReduceWidthByte*8 val Data = Flipped(Vec(ABMatrixRegNBanks, Valid(UInt(ABMatrixRegEntryBitSize.W)))) + val ByteMask = Flipped(Vec(ABMatrixRegNBanks, Valid(UInt(ABMatrixRegEntryByteSize.W)))) +} + +class ABScaleLoaderMatrixRegIO(implicit p: Parameters) extends CuteBundle{ + //bankaddr is the row-select signal for each of the nbanks banks; it is a Vec with nbanks elements, each a UInt whose width is log2Ceil(AScratchpadBankNLines), and it is input data that requires handshaking + val BankAddr = Flipped(Valid(UInt(log2Ceil(ABScaleBankNEntries).W))) + //bankdata is the row data for each of the nbanks banks; it is a Vec with nbanks elements, each a UInt whose width is ReduceWidthByte*8 + val Data = Flipped(Valid(Vec(ABScaleNSlices, UInt((ScaleWidth * ReduceGroupSize).W)))) + //chosen selects the ScratchPad; it is a Bool. We use double buffering, selecting one for output and one for loading + // val Chosen = Input(Bool()) } class ABScaleLoaderMatrixRegIO(implicit p: Parameters) extends CuteBundle{ @@ -1165,6 +1190,7 @@ class CMemoryLoaderMatrixRegIO(implicit p: Parameters) extends CuteBundle{ val WriteRequestToMatrixReg = (new Bundle{ val BankAddr = Flipped(Vec(CMatrixRegNBanks, (Valid(UInt(log2Ceil(CMatrixRegBankNEntries).W))))) val Data = Flipped(Vec(CMatrixRegNBanks, (Valid(UInt(CMatrixRegEntryBitSize.W))))) + val ByteMask = Flipped(Vec(CMatrixRegNBanks, Valid(UInt(CMatrixRegEntryByteSize.W)))) }) val LoadReadWriteRequest = Input(UInt((MatrixRegTaskType.TaskTypeBitWidth).W)) val StoreReadWriteRequest = Input(UInt((MatrixRegTaskType.TaskTypeBitWidth).W)) @@ -1183,6 +1209,7 @@ class LocalMMUIO(implicit p: Parameters) extends CuteBundle{ val RequestData = UInt(MMUDataWidth.W) val RequestSourceID = UInt(SoureceMaxNumBitSize.W) val RequestType_isWrite = Bool() + val RequestMask = UInt(MMUMaskWidth.W) //MMU byte mask })) //transaction ID of the TL link to which the read request is dispatched val ConherentRequsetSourceID = Valid(UInt(LLCSourceMaxNumBitSize.W)) @@ -1350,6 +1377,13 @@ class FReducePEDataType { case object ElementDataType extends Field[UInt]{ val DataTypeBitWidth = 4 val DataTypeUndef = 0.U(DataTypeBitWidth.W) + val DataTypeI8I8I32 = MteComputeType.I8I8I32 + val DataTypeI8U8I32 = MteComputeType.I8U8I32 + val DataTypeU8I8I32 = MteComputeType.U8I8I32 + val DataTypeU8U8I32 = MteComputeType.U8U8I32 + val DataTypeF16F16F32 = MteComputeType.F16F16F32 + val DataTypeBF16BF16F32 = MteComputeType.BF16BF16F32 + val DataTypeTF32TF32F32 = MteComputeType.TF32TF32F32 val DataTypeWidth32 = 4.U(DataTypeBitWidth.W) val DataTypeWidth16 = 2.U(DataTypeBitWidth.W) val DataTypeWidth8 = 1.U(DataTypeBitWidth.W) diff --git a/src/main/scala/CUTETOP.scala b/src/main/scala/CUTETOP.scala index 654c966..a0ead15 100644 --- a/src/main/scala/CUTETOP.scala +++ b/src/main/scala/CUTETOP.scala @@ -185,6 +185,17 @@ class CUTEV2Top()(implicit p: Parameters) extends CuteModule{ //MemoryLoader的请求 ABMatrixRegs(i).io.MatrixRegIO.FromMemoryLoader.BankAddr := 0.U.asTypeOf(ABMatrixRegs(i).io.MatrixRegIO.FromMemoryLoader.BankAddr) ABMatrixRegs(i).io.MatrixRegIO.FromMemoryLoader.Data := 0.U.asTypeOf(ABMatrixRegs(i).io.MatrixRegIO.FromMemoryLoader.Data) + ABMatrixRegs(i).io.MatrixRegIO.FromMemoryLoader.ByteMask := 0.U.asTypeOf(ABMatrixRegs(i).io.MatrixRegIO.FromMemoryLoader.ByteMask) + } + + if (cuteMatrixExtension.enableScalingFactor) { + // AB Scale Regs + (ASMRegs.get ++ BSMRegs.get).foreach { reg => + reg.io.FromScaleController.BankAddr.valid := false.B + reg.io.FromScaleController.BankAddr.bits := 0.U.asTypeOf(reg.io.FromScaleController.BankAddr.bits) + reg.io.FromScaleLoader.BankAddr := 0.U.asTypeOf(reg.io.FromScaleLoader.BankAddr) + reg.io.FromScaleLoader.Data := 0.U.asTypeOf(reg.io.FromScaleLoader.Data) + } } if (cuteMatrixExtension.enableScalingFactor) { @@ -269,6 +280,8 @@ class CUTEV2Top()(implicit p: Parameters) extends CuteModule{ dest.BankAddr(b).bits := DontCare dest.Data(b).valid := false.B dest.Data(b).bits := DontCare + dest.ByteMask(b).valid := false.B + dest.ByteMask(b).bits := DontCare } } @@ -278,6 +291,8 @@ class CUTEV2Top()(implicit p: Parameters) extends CuteModule{ dest.BankAddr(b).bits := src.BankAddr(b).bits dest.Data(b).valid := src.Data(b).valid dest.Data(b).bits := src.Data(b).bits + dest.ByteMask(b).valid := src.ByteMask(b).valid + dest.ByteMask(b).bits := src.ByteMask(b).bits } } @@ -362,9 +377,11 @@ class CUTEV2Top()(implicit p: Parameters) extends CuteModule{ dest.ReadRequestToMatrixReg.BankAddr(bank).valid := storeSel && CML.io.ToMatrixRegIO.ReadRequestToMatrixReg.BankAddr(bank).valid dest.WriteRequestToMatrixReg.BankAddr(bank).valid := loadSel && CML.io.ToMatrixRegIO.WriteRequestToMatrixReg.BankAddr(bank).valid dest.WriteRequestToMatrixReg.Data(bank).valid := loadSel && CML.io.ToMatrixRegIO.WriteRequestToMatrixReg.Data(bank).valid + dest.WriteRequestToMatrixReg.ByteMask(bank).valid := loadSel && CML.io.ToMatrixRegIO.WriteRequestToMatrixReg.ByteMask(bank).valid dest.ReadRequestToMatrixReg.BankAddr(bank).bits := CML.io.ToMatrixRegIO.ReadRequestToMatrixReg.BankAddr(bank).bits dest.WriteRequestToMatrixReg.BankAddr(bank).bits := CML.io.ToMatrixRegIO.WriteRequestToMatrixReg.BankAddr(bank).bits dest.WriteRequestToMatrixReg.Data(bank).bits := CML.io.ToMatrixRegIO.WriteRequestToMatrixReg.Data(bank).bits + dest.WriteRequestToMatrixReg.ByteMask(bank).bits := CML.io.ToMatrixRegIO.WriteRequestToMatrixReg.ByteMask(bank).bits } val loadReqForThisReg = Mux(loadSel, CML.io.ToMatrixRegIO.LoadReadWriteRequest, 0.U) diff --git a/src/main/scala/LocalMMU.scala b/src/main/scala/LocalMMU.scala index abeb35c..222c135 100644 --- a/src/main/scala/LocalMMU.scala +++ b/src/main/scala/LocalMMU.scala @@ -84,6 +84,7 @@ class LocalMMU()(implicit p: Parameters) extends CuteModule{ io.LastLevelCacheTLIO.Request.valid := false.B io.LastLevelCacheTLIO.Request.bits := DontCare io.LastLevelCacheTLIO.Response.ready := false.B + val selectedRequestMask = WireDefault(Fill(MMUMaskWidth, 1.U(1.W))) when(io.LastLevelCacheTLIO.ConherentRequsetSourceID.valid && HasRequest) { @@ -93,6 +94,7 @@ class LocalMMU()(implicit p: Parameters) extends CuteModule{ io.ALocalMMUIO.ConherentRequsetSourceID := io.LastLevelCacheTLIO.ConherentRequsetSourceID io.LastLevelCacheTLIO.Request.bits.RequestPhysicalAddr := io.ALocalMMUIO.Request.bits.RequestVirtualAddr io.LastLevelCacheTLIO.Request.bits.RequestType_isWrite := false.B + selectedRequestMask := io.ALocalMMUIO.Request.bits.RequestMask sourceid2port(io.LastLevelCacheTLIO.ConherentRequsetSourceID.bits) := LocalMMUTaskType.AFirst io.LastLevelCacheTLIO.Request.bits.MatrixIsAcc := false.B } @@ -111,6 +113,7 @@ class LocalMMU()(implicit p: Parameters) extends CuteModule{ io.BLocalMMUIO.ConherentRequsetSourceID := io.LastLevelCacheTLIO.ConherentRequsetSourceID io.LastLevelCacheTLIO.Request.bits.RequestPhysicalAddr := io.BLocalMMUIO.Request.bits.RequestVirtualAddr io.LastLevelCacheTLIO.Request.bits.RequestType_isWrite := false.B + selectedRequestMask := io.BLocalMMUIO.Request.bits.RequestMask sourceid2port(io.LastLevelCacheTLIO.ConherentRequsetSourceID.bits) := LocalMMUTaskType.BFirst io.LastLevelCacheTLIO.Request.bits.MatrixIsAcc := false.B } @@ -130,6 +133,7 @@ class LocalMMU()(implicit p: Parameters) extends CuteModule{ io.LastLevelCacheTLIO.Request.bits.RequestPhysicalAddr := io.CLoadLocalMMUIO.Request.bits.RequestVirtualAddr io.LastLevelCacheTLIO.Request.bits.RequestData := io.CLoadLocalMMUIO.Request.bits.RequestData io.LastLevelCacheTLIO.Request.bits.RequestType_isWrite := io.CLoadLocalMMUIO.Request.bits.RequestType_isWrite + selectedRequestMask := io.CLoadLocalMMUIO.Request.bits.RequestMask sourceid2port(io.LastLevelCacheTLIO.ConherentRequsetSourceID.bits) := LocalMMUTaskType.CLoadFirst io.LastLevelCacheTLIO.Request.bits.MatrixIsAcc := true.B } @@ -139,12 +143,13 @@ class LocalMMU()(implicit p: Parameters) extends CuteModule{ io.LastLevelCacheTLIO.Request.bits.RequestPhysicalAddr := io.CStoreLocalMMUIO.Request.bits.RequestVirtualAddr io.LastLevelCacheTLIO.Request.bits.RequestData := io.CStoreLocalMMUIO.Request.bits.RequestData io.LastLevelCacheTLIO.Request.bits.RequestType_isWrite := io.CStoreLocalMMUIO.Request.bits.RequestType_isWrite + selectedRequestMask := io.CStoreLocalMMUIO.Request.bits.RequestMask sourceid2port(io.LastLevelCacheTLIO.ConherentRequsetSourceID.bits) := LocalMMUTaskType.CStoreFirst io.LastLevelCacheTLIO.Request.bits.MatrixIsAcc := true.B } } - io.LastLevelCacheTLIO.Request.bits.RequestMask := Fill(MMUMaskWidth, 1.U(1.W)) + io.LastLevelCacheTLIO.Request.bits.RequestMask := selectedRequestMask io.LastLevelCacheTLIO.Request.bits.RequestConherent := true.B io.LastLevelCacheTLIO.Request.bits.RequestSourceID := io.LastLevelCacheTLIO.ConherentRequsetSourceID.bits io.LastLevelCacheTLIO.Request.valid := true.B diff --git a/src/main/scala/TaskController.scala b/src/main/scala/TaskController.scala index b6fc9de..1887cc3 100644 --- a/src/main/scala/TaskController.scala +++ b/src/main/scala/TaskController.scala @@ -138,6 +138,7 @@ class TaskController(implicit p: Parameters) extends BaseTaskController { io.AML_MicroTask_Config.MatrixRegTensor_M := 0.U io.AML_MicroTask_Config.MatrixRegTensor_K := 0.U io.AML_MicroTask_Config.Conherent := false.B + io.AML_MicroTask_Config.Is_Transpose := false.B io.AML_MicroTask_Config.MatrixRegId := 0.U io.AML_MicroTask_Config.MicroTaskValid := false.B io.AML_MicroTask_Config.MicroTaskEndReady := false.B @@ -159,6 +160,7 @@ class TaskController(implicit p: Parameters) extends BaseTaskController { io.BML_MicroTask_Config.MatrixRegTensor_N := 0.U io.BML_MicroTask_Config.MatrixRegTensor_K := 0.U io.BML_MicroTask_Config.Conherent := false.B + io.BML_MicroTask_Config.Is_Transpose := false.B io.BML_MicroTask_Config.MatrixRegId := 0.U io.BML_MicroTask_Config.MicroTaskValid := false.B io.BML_MicroTask_Config.MicroTaskEndReady := false.B @@ -302,6 +304,11 @@ class TaskController(implicit p: Parameters) extends BaseTaskController { decEntryEnq.writeValid(i) := false.B } + + val enqMma = decodeMma(amuCtrlBits) + val enqLsu = decodeLsu(amuCtrlBits) + val enqArith = decodeArith(amuCtrlBits) + val enqMma = decodeMma(amuCtrlBits) val enqLsu = decodeLsu(amuCtrlBits) val enqArith = decodeArith(amuCtrlBits) @@ -602,6 +609,33 @@ class TaskController(implicit p: Parameters) extends BaseTaskController { )) } + private val loadByteCountBits = MatrixRegMaxTensorDimBitSize + 2 + + private def loadByteCount(dim: UInt, widths: UInt): UInt = { + val dimWide = dim.pad(loadByteCountBits) + val byteCount = WireDefault((dimWide << 2)(loadByteCountBits - 1, 0)) + switch(widths) { + is(Bundles.MSew.e8) { byteCount := dimWide } + is(Bundles.MSew.e16) { byteCount := (dimWide << 1)(loadByteCountBits - 1, 0) } + is(Bundles.MSew.e32) { byteCount := (dimWide << 2)(loadByteCountBits - 1, 0) } + is(Bundles.MSew.e4) { byteCount := dimWide >> 1 } + } + byteCount + } + + private def loadBeatCount(dim: UInt, widths: UInt): UInt = { + val beatCount = (loadByteCount(dim, widths) + (outsideDataWidthByte - 1).U) >> log2Ceil(outsideDataWidthByte) + beatCount.pad(MatrixRegMaxTensorDimBitSize)(MatrixRegMaxTensorDimBitSize - 1, 0) + } + + private def loadHasTail(dim: UInt, widths: UInt): Bool = { + loadByteCount(dim, widths)(log2Ceil(outsideDataWidthByte) - 1, 0).orR + } + + private def loadTailByteMask(dim: UInt, widths: UInt): UInt = { + Cat(0.U((log2Ceil(outsideDataWidthByte + 1) - log2Ceil(outsideDataWidthByte)).W), loadByteCount(dim, widths)(log2Ceil(outsideDataWidthByte) - 1, 0)) + } + private def decodeMmaComputeType(mma: AmuMmaIO): UInt = { val computeType = WireInit(MteComputeType.ComputeTypeUndef) when(mma.isfp) { @@ -642,18 +676,24 @@ class TaskController(implicit p: Parameters) extends BaseTaskController { switch(issueSlot.opKind) { is(TaskCtrlOpKind.LoadA) { val regIdx = issueLsu.ms(1, 0) - val kVal = computeKFromMsew(issueLsu.column, issueLsu.widths) / ReduceWidthByte.U + val matrixDim = Mux(issueLsu.transpose, issueLsu.column, issueLsu.row) + val reduceDim = Mux(issueLsu.transpose, issueLsu.row, issueLsu.column) + val kVal = loadBeatCount(reduceDim, issueLsu.widths) io.AML_MicroTask_Config.ApplicationTensor_A.ApplicationTensor_A_BaseVaddr := issueLsu.baseAddr io.AML_MicroTask_Config.ApplicationTensor_A.ApplicationTensor_A_Stride_M := issueLsu.stride io.AML_MicroTask_Config.ApplicationTensor_A.dataType := loadDataType(issueLsu.widths) + io.AML_MicroTask_Config.ApplicationTensor_A.HasTail := loadHasTail(reduceDim, issueLsu.widths) + io.AML_MicroTask_Config.ApplicationTensor_A.TailByteMask := loadTailByteMask(reduceDim, issueLsu.widths) + io.AML_MicroTask_Config.ApplicationTensor_A.K_Beat_Count := kVal io.AML_MicroTask_Config.LoadTaskInfo.Is_FullLoad := true.B io.AML_MicroTask_Config.LoadTaskInfo.Is_ZeroLoad := false.B io.AML_MicroTask_Config.LoadTaskInfo.Is_RepeatRowLoad := false.B - io.AML_MicroTask_Config.MatrixRegTensor_M := issueLsu.row + io.AML_MicroTask_Config.MatrixRegTensor_M := matrixDim io.AML_MicroTask_Config.MatrixRegTensor_K := kVal io.AML_MicroTask_Config.MatrixRegId := regIdx io.AML_MicroTask_Config.Conherent := true.B + io.AML_MicroTask_Config.Is_Transpose := issueLsu.transpose io.AML_MicroTask_Config.MicroTaskValid := true.B if (EnableDifftest) { io.AML_MicroTask_Config.pc.get := issueCtrl.pc.get @@ -687,16 +727,22 @@ class TaskController(implicit p: Parameters) extends BaseTaskController { is(TaskCtrlOpKind.LoadB) { val regIdx = issueLsu.ms(1, 0) - val kVal = computeKFromMsew(issueLsu.row, issueLsu.widths) / ReduceWidthByte.U + val matrixDim = Mux(issueLsu.transpose, issueLsu.row, issueLsu.column) + val reduceDim = Mux(issueLsu.transpose, issueLsu.column, issueLsu.row) + val kVal = loadBeatCount(reduceDim, issueLsu.widths) io.BML_MicroTask_Config.ApplicationTensor_B.ApplicationTensor_B_BaseVaddr := issueLsu.baseAddr io.BML_MicroTask_Config.ApplicationTensor_B.ApplicationTensor_B_Stride_N := issueLsu.stride io.BML_MicroTask_Config.ApplicationTensor_B.BlockTensor_B_BaseVaddr := issueLsu.baseAddr io.BML_MicroTask_Config.ApplicationTensor_B.dataType := loadDataType(issueLsu.widths) - io.BML_MicroTask_Config.MatrixRegTensor_N := issueLsu.column + io.BML_MicroTask_Config.ApplicationTensor_B.HasTail := loadHasTail(reduceDim, issueLsu.widths) + io.BML_MicroTask_Config.ApplicationTensor_B.TailByteMask := loadTailByteMask(reduceDim, issueLsu.widths) + io.BML_MicroTask_Config.ApplicationTensor_B.K_Beat_Count := kVal + io.BML_MicroTask_Config.MatrixRegTensor_N := matrixDim io.BML_MicroTask_Config.MatrixRegTensor_K := kVal io.BML_MicroTask_Config.MatrixRegId := regIdx io.BML_MicroTask_Config.Conherent := true.B + io.BML_MicroTask_Config.Is_Transpose := issueLsu.transpose io.BML_MicroTask_Config.MicroTaskValid := true.B if (EnableDifftest) { io.BML_MicroTask_Config.pc.get := issueCtrl.pc.get @@ -730,17 +776,23 @@ class TaskController(implicit p: Parameters) extends BaseTaskController { is(TaskCtrlOpKind.LoadC) { val regIdx = issueLsu.ms(1, 0) + val matrixDimM = Mux(issueLsu.transpose, issueLsu.column, issueLsu.row) + val matrixDimN = Mux(issueLsu.transpose, issueLsu.row, issueLsu.column) + val nVal = loadBeatCount(matrixDimN, issueLsu.widths) io.CML_MicroTask_Config.ApplicationTensor_C.ApplicationTensor_C_BaseVaddr := issueLsu.baseAddr io.CML_MicroTask_Config.ApplicationTensor_C.ApplicationTensor_C_Stride_M := issueLsu.stride io.CML_MicroTask_Config.ApplicationTensor_C.BlockTensor_C_BaseVaddr := issueLsu.baseAddr io.CML_MicroTask_Config.ApplicationTensor_C.dataType := loadDataType(issueLsu.widths) + io.CML_MicroTask_Config.ApplicationTensor_C.HasTail := loadHasTail(matrixDimN, issueLsu.widths) + io.CML_MicroTask_Config.ApplicationTensor_C.TailByteMask := loadTailByteMask(matrixDimN, issueLsu.widths) + io.CML_MicroTask_Config.ApplicationTensor_C.N_Beat_Count := nVal io.CML_MicroTask_Config.Conherent := true.B io.CML_MicroTask_Config.LoadTaskInfo.Is_FullLoad := true.B io.CML_MicroTask_Config.LoadTaskInfo.Is_ZeroLoad := false.B io.CML_MicroTask_Config.LoadTaskInfo.Is_RepeatRowLoad := false.B - io.CML_MicroTask_Config.MatrixRegTensor_M := issueLsu.row - io.CML_MicroTask_Config.MatrixRegTensor_N := issueLsu.column + io.CML_MicroTask_Config.MatrixRegTensor_M := matrixDimM + io.CML_MicroTask_Config.MatrixRegTensor_N := matrixDimN io.CML_MicroTask_Config.MatrixRegId := regIdx io.CML_MicroTask_Config.Is_Transpose := issueLsu.transpose io.CML_MicroTask_Config.LoadMicroTaskValid := true.B @@ -831,6 +883,7 @@ class TaskController(implicit p: Parameters) extends BaseTaskController { io.AML_MicroTask_Config.LoadTaskInfo.Is_FullLoad := false.B io.AML_MicroTask_Config.LoadTaskInfo.Is_RepeatRowLoad := false.B io.AML_MicroTask_Config.Conherent := true.B + io.AML_MicroTask_Config.Is_Transpose := false.B if (EnableDifftest) { io.AML_MicroTask_Config.pc.get := issueCtrl.pc.get io.AML_MicroTask_Config.coreid.get := issueCtrl.coreid.get @@ -1390,6 +1443,7 @@ class TaskController(implicit p: Parameters) extends BaseTaskController { releaseFinish.pc := Mux(issueFire && issueSlot.opKind === TaskCtrlOpKind.Release, releaseIssueOwnerEntry.ctrl.pc.get, 0.U) releaseFinish.bankValid.foreach(_ := false.B) releaseFinish.bankAddr.foreach(_ := 0.U) + releaseFinish.bankMask.foreach(_ := 0.U) releaseFinish.data.foreach(_ := 0.U) releaseFinish.finish := issueFire && issueSlot.opKind === TaskCtrlOpKind.Release