diff --git a/lib/Dialect/FIRRTL/Transforms/InferDomains.cpp b/lib/Dialect/FIRRTL/Transforms/InferDomains.cpp index 830f70b07f5b..83ce76c24856 100644 --- a/lib/Dialect/FIRRTL/Transforms/InferDomains.cpp +++ b/lib/Dialect/FIRRTL/Transforms/InferDomains.cpp @@ -970,6 +970,13 @@ using PendingSolutions = DenseMap; /// port index. using PendingExports = llvm::MapVector; +/// A map from domain values to their origin (instance name + port name). +/// Used to generate better names for inferred domain ports. +using DomainOriginTable = DenseMap; + +/// A map from terms to their origin. Used to track origins through unification. +using TermOriginTable = DenseMap; + namespace { struct PendingUpdates { PortInsertions insertions; @@ -981,6 +988,8 @@ struct PendingUpdates { /// If `var` is not solved, solve it by recording a pending input port at /// the indicated insertion point. static void ensureSolved(const DomainInfo &info, Namespace &ns, + const DomainOriginTable &valueOrigins, + const TermOriginTable &termOrigins, DomainTypeID typeID, size_t ip, LocationAttr loc, VariableTerm *var, PendingUpdates &pending) { if (pending.solutions.contains(var)) @@ -990,7 +999,21 @@ static void ensureSolved(const DomainInfo &info, Namespace &ns, auto domainDecl = info.getDomain(typeID); auto domainName = domainDecl.getNameAttr(); - auto portName = StringAttr::get(context, ns.newName(domainName.getValue())); + // Try to find an origin name for this variable. Check both the term origin + // table and the value origin table. + StringRef baseName = domainName.getValue(); + + // First, check if this term itself has an origin. + Term *root = find(var); + if (auto it = termOrigins.find(root); it != termOrigins.end()) { + baseName = it->second; + } else if (auto *val = dyn_cast(root)) { + // If the term is a value, check the value origin table. + if (auto it = valueOrigins.find(val->value); it != valueOrigins.end()) + baseName = it->second; + } + + auto portName = StringAttr::get(context, ns.newName(baseName)); auto portType = DomainType::getFromDomainOp(domainDecl); auto portDirection = Direction::In; auto portSym = StringAttr(); @@ -1011,6 +1034,7 @@ static void ensureSolved(const DomainInfo &info, Namespace &ns, // an /// output port. static void ensureExported(const DomainInfo &info, Namespace &ns, + const DomainOriginTable &origins, const ExportTable &exports, DomainTypeID typeID, size_t ip, LocationAttr loc, ValueTerm *val, PendingUpdates &pending) { @@ -1025,7 +1049,12 @@ static void ensureExported(const DomainInfo &info, Namespace &ns, auto domainDecl = info.getDomain(typeID); auto domainName = domainDecl.getNameAttr(); - auto portName = StringAttr::get(context, ns.newName(domainName.getValue())); + // Use origin name if available, otherwise fall back to domain name. + StringRef baseName = domainName.getValue(); + if (auto it = origins.find(value); it != origins.end()) + baseName = it->second; + + auto portName = StringAttr::get(context, ns.newName(baseName)); auto portType = DomainType::getFromDomainOp(domainDecl); auto portDirection = Direction::Out; auto portSym = StringAttr(); @@ -1039,44 +1068,49 @@ static void ensureExported(const DomainInfo &info, Namespace &ns, pending.insertions.push_back({ip, portInfo}); } -static void getUpdatesForDomainAssociationOfPort(const DomainInfo &info, - Namespace &ns, - PendingUpdates &pending, - DomainTypeID typeID, size_t ip, - LocationAttr loc, Term *term, - const ExportTable &exports) { +static void getUpdatesForDomainAssociationOfPort( + const DomainInfo &info, Namespace &ns, + const DomainOriginTable &valueOrigins, const TermOriginTable &termOrigins, + PendingUpdates &pending, DomainTypeID typeID, size_t ip, LocationAttr loc, + Term *term, const ExportTable &exports) { if (auto *var = dyn_cast(term)) { - ensureSolved(info, ns, typeID, ip, loc, var, pending); + ensureSolved(info, ns, valueOrigins, termOrigins, typeID, ip, loc, var, + pending); return; } if (auto *val = dyn_cast(term)) { - ensureExported(info, ns, exports, typeID, ip, loc, val, pending); + ensureExported(info, ns, valueOrigins, exports, typeID, ip, loc, val, + pending); return; } llvm_unreachable("invalid domain association"); } static void getUpdatesForDomainAssociationOfPort( - const DomainInfo &info, Namespace &ns, const ExportTable &exports, - size_t ip, LocationAttr loc, RowTerm *row, PendingUpdates &pending) { + const DomainInfo &info, Namespace &ns, + const DomainOriginTable &valueOrigins, const TermOriginTable &termOrigins, + const ExportTable &exports, size_t ip, LocationAttr loc, RowTerm *row, + PendingUpdates &pending) { for (auto [index, term] : llvm::enumerate(row->elements)) - getUpdatesForDomainAssociationOfPort(info, ns, pending, DomainTypeID{index}, - ip, loc, find(term), exports); + getUpdatesForDomainAssociationOfPort(info, ns, valueOrigins, termOrigins, + pending, DomainTypeID{index}, ip, loc, + find(term), exports); } -static void getUpdatesForModulePorts(const DomainInfo &info, - TermAllocator &allocator, - const ExportTable &exports, - DomainTable &table, Namespace &ns, - FModuleOp moduleOp, - PendingUpdates &pending) { +static void +getUpdatesForModulePorts(const DomainInfo &info, TermAllocator &allocator, + const ExportTable &exports, DomainTable &table, + Namespace &ns, const DomainOriginTable &valueOrigins, + const TermOriginTable &termOrigins, FModuleOp moduleOp, + PendingUpdates &pending) { for (size_t i = 0, e = moduleOp.getNumPorts(); i < e; ++i) { auto port = moduleOp.getArgument(i); auto type = port.getType(); if (!isa(type)) continue; getUpdatesForDomainAssociationOfPort( - info, ns, exports, i, moduleOp.getPortLocation(i), + info, ns, valueOrigins, termOrigins, exports, i, + moduleOp.getPortLocation(i), getDomainAssociationAsRow(info, allocator, table, port), pending); } } @@ -1084,12 +1118,16 @@ static void getUpdatesForModulePorts(const DomainInfo &info, static void getUpdatesForModule(const DomainInfo &info, TermAllocator &allocator, const ExportTable &exports, DomainTable &table, - FModuleOp mod, PendingUpdates &pending) { + FModuleOp mod, PendingUpdates &pending, + const DomainOriginTable &valueOrigins, + const TermOriginTable &termOrigins) { Namespace ns; auto names = mod.getPortNamesAttr(); for (auto name : names.getAsRange()) ns.add(name); - getUpdatesForModulePorts(info, allocator, exports, table, ns, mod, pending); + + getUpdatesForModulePorts(info, allocator, exports, table, ns, valueOrigins, + termOrigins, mod, pending); } static void applyUpdatesToModule(const DomainInfo &info, @@ -1282,8 +1320,7 @@ static LogicalResult updateInstance(const DomainInfo &info, return success(); } -/// After updating the port domain associations, walk the body of the moduleOp -/// to fix up any child instance modules. +/// Update the module body by creating domain.define operations for instances. static LogicalResult updateModuleBody(const DomainInfo &info, TermAllocator &allocator, DomainTable &table, FModuleOp moduleOp) { @@ -1299,13 +1336,56 @@ static LogicalResult updateModuleBody(const DomainInfo &info, return failure(result.wasInterrupted()); } +/// Build domain origin tables by walking instances. This is done in a separate +/// function so it can be called before getUpdatesForModule. +static void buildDomainOriginTables(const DomainTable &table, FModuleOp mod, + DomainOriginTable &valueOrigins, + TermOriginTable &termOrigins) { + mod.getBodyBlock()->walk([&](FInstanceLike op) { + auto instanceName = op.getInstanceName(); + for (size_t i = 0, e = op->getNumResults(); i < e; ++i) { + auto port = dyn_cast(op->getResult(i)); + if (!port) + continue; + + auto portName = op.getPortNameAttr(i); + std::string originName = (instanceName + "_" + portName.getValue()).str(); + + // Track this origin for the domain value (use first occurrence). + if (!valueOrigins.count(port)) + valueOrigins[port] = originName; + + // Also track the origin for the term associated with this port. + // This is crucial for input ports where the term might be a variable. + if (auto *term = table.getOptTermForDomain(port)) { + Term *root = find(term); + if (!termOrigins.count(root)) + termOrigins[root] = originName; + + // If the term is a ValueTerm, also track the value. + if (auto *val = dyn_cast(root)) { + if (!valueOrigins.count(val->value)) + valueOrigins[val->value] = originName; + } + } + } + }); +} + /// Write the domain associations recorded in the domain table back to the IR. static LogicalResult updateModule(const DomainInfo &info, TermAllocator &allocator, DomainTable &table, ModuleUpdateTable &updates, FModuleOp op) { + // Build domain origin tables by walking instances. This will be used for + // generating better port names. + DomainOriginTable valueOrigins; + TermOriginTable termOrigins; + buildDomainOriginTables(table, op, valueOrigins, termOrigins); + auto exports = initializeExportTable(table, op); PendingUpdates pending; - getUpdatesForModule(info, allocator, exports, table, op, pending); + getUpdatesForModule(info, allocator, exports, table, op, pending, + valueOrigins, termOrigins); applyUpdatesToModule(info, allocator, exports, table, op, pending); // Update the domain info for the moduleOp's ports. diff --git a/test/Dialect/FIRRTL/infer-domains-infer-all.mlir b/test/Dialect/FIRRTL/infer-domains-infer-all.mlir index 9bf42e716902..a56ac0304ec3 100644 --- a/test/Dialect/FIRRTL/infer-domains-infer-all.mlir +++ b/test/Dialect/FIRRTL/infer-domains-infer-all.mlir @@ -168,8 +168,8 @@ firrtl.circuit "ExportDomain" { ) firrtl.module @ExportDomain( - // CHECK: out %ClockDomain: !firrtl.domain<@ClockDomain()> - // CHECK: out %o: !firrtl.uint<1> domains [%ClockDomain] + // CHECK: out %foo_A: !firrtl.domain<@ClockDomain()> + // CHECK: out %o: !firrtl.uint<1> domains [%foo_A] out %o: !firrtl.uint<1> ) { %foo_A, %foo_o = firrtl.instance foo @Foo( @@ -177,7 +177,7 @@ firrtl.circuit "ExportDomain" { out o: !firrtl.uint<1> domains [A] ) firrtl.matchingconnect %o, %foo_o : !firrtl.uint<1> - // CHECK: firrtl.domain.define %ClockDomain, %foo_A : !firrtl.domain<@ClockDomain()> + // CHECK: firrtl.domain.define %foo_A, %foo_A_0 : !firrtl.domain<@ClockDomain()> } } @@ -228,10 +228,10 @@ firrtl.circuit "InstanceUpdate" { firrtl.module @Foo(in %i : !firrtl.uint<1>) {} - // CHECK: firrtl.module @InstanceUpdate(in %ClockDomain: !firrtl.domain<@ClockDomain()>, in %i: !firrtl.uint<1> domains [%ClockDomain]) { - // CHECK: %foo_ClockDomain, %foo_i = firrtl.instance foo @Foo(in ClockDomain: !firrtl.domain<@ClockDomain()>, in i: !firrtl.uint<1> domains [ClockDomain]) + // CHECK: firrtl.module @InstanceUpdate(in %foo_ClockDomain: !firrtl.domain<@ClockDomain()>, in %i: !firrtl.uint<1> domains [%foo_ClockDomain]) { + // CHECK: %foo_ClockDomain_0, %foo_i = firrtl.instance foo @Foo(in ClockDomain: !firrtl.domain<@ClockDomain()>, in i: !firrtl.uint<1> domains [ClockDomain]) // CHECK: firrtl.connect %foo_i, %i : !firrtl.uint<1> - // CHECK: firrtl.domain.define %foo_ClockDomain, %ClockDomain : !firrtl.domain<@ClockDomain()> + // CHECK: firrtl.domain.define %foo_ClockDomain_0, %foo_ClockDomain : !firrtl.domain<@ClockDomain()> // CHECK: } firrtl.module @InstanceUpdate(in %i : !firrtl.uint<1>) { %foo_i = firrtl.instance foo @Foo(in i: !firrtl.uint<1>) @@ -252,10 +252,10 @@ firrtl.circuit "InstanceChoiceUpdate" { firrtl.module @Bar(in %i : !firrtl.uint<1>) {} firrtl.module @Baz(in %i : !firrtl.uint<1>) {} - // CHECK: firrtl.module @InstanceChoiceUpdate(in %ClockDomain: !firrtl.domain<@ClockDomain()>, in %i: !firrtl.uint<1> domains [%ClockDomain]) { - // CHECK: %inst_ClockDomain, %inst_i = firrtl.instance_choice inst @Foo alternatives @Option { @X -> @Bar, @Y -> @Baz } (in ClockDomain: !firrtl.domain<@ClockDomain()>, in i: !firrtl.uint<1> domains [ClockDomain]) + // CHECK: firrtl.module @InstanceChoiceUpdate(in %inst_ClockDomain: !firrtl.domain<@ClockDomain()>, in %i: !firrtl.uint<1> domains [%inst_ClockDomain]) { + // CHECK: %inst_ClockDomain_0, %inst_i = firrtl.instance_choice inst @Foo alternatives @Option { @X -> @Bar, @Y -> @Baz } (in ClockDomain: !firrtl.domain<@ClockDomain()>, in i: !firrtl.uint<1> domains [ClockDomain]) // CHECK: firrtl.connect %inst_i, %i : !firrtl.uint<1> - // CHECK: firrtl.domain.define %inst_ClockDomain, %ClockDomain : !firrtl.domain<@ClockDomain()> + // CHECK: firrtl.domain.define %inst_ClockDomain_0, %inst_ClockDomain : !firrtl.domain<@ClockDomain()> // CHECK: } firrtl.module @InstanceChoiceUpdate(in %i : !firrtl.uint<1>) { %inst_i = firrtl.instance_choice inst @Foo alternatives @Option { @X -> @Bar, @Y -> @Baz } (in i : !firrtl.uint<1>) diff --git a/test/Dialect/FIRRTL/infer-domains-infer.mlir b/test/Dialect/FIRRTL/infer-domains-infer.mlir index f97b33abd0f3..271d9fd64dbd 100644 --- a/test/Dialect/FIRRTL/infer-domains-infer.mlir +++ b/test/Dialect/FIRRTL/infer-domains-infer.mlir @@ -5,9 +5,9 @@ firrtl.circuit "InferOutputDomain" { firrtl.extmodule @Foo(out D: !firrtl.domain<@ClockDomain()>, out x: !firrtl.uint<1> domains [D]) - // CHECK: firrtl.module private @Bar(out %ClockDomain: !firrtl.domain<@ClockDomain()>, out %x: !firrtl.uint<1> domains [%ClockDomain]) { - // CHECK: %foo_D, %foo_x = firrtl.instance foo @Foo(out D: !firrtl.domain<@ClockDomain()>, out x: !firrtl.uint<1> domains [D]) - // CHECK: firrtl.domain.define %ClockDomain, %foo_D : !firrtl.domain<@ClockDomain()> + // CHECK: firrtl.module private @Bar(out %foo_D: !firrtl.domain<@ClockDomain()>, out %x: !firrtl.uint<1> domains [%foo_D]) { + // CHECK: %foo_D_0, %foo_x = firrtl.instance foo @Foo(out D: !firrtl.domain<@ClockDomain()>, out x: !firrtl.uint<1> domains [D]) + // CHECK: firrtl.domain.define %foo_D, %foo_D_0 : !firrtl.domain<@ClockDomain()> // CHECK: firrtl.matchingconnect %x, %foo_x : !firrtl.uint<1> // CHECK: } firrtl.module private @Bar(out %x : !firrtl.uint<1>) { @@ -16,9 +16,9 @@ firrtl.circuit "InferOutputDomain" { } // CHECK: firrtl.module @InferOutputDomain(out %D: !firrtl.domain<@ClockDomain()>, out %x: !firrtl.uint<1> domains [%D]) { - // CHECK: %bar_ClockDomain, %bar_x = firrtl.instance bar @Bar(out ClockDomain: !firrtl.domain<@ClockDomain()>, out x: !firrtl.uint<1> domains [ClockDomain]) + // CHECK: %bar_foo_D, %bar_x = firrtl.instance bar @Bar(out foo_D: !firrtl.domain<@ClockDomain()>, out x: !firrtl.uint<1> domains [foo_D]) // CHECK: firrtl.matchingconnect %x, %bar_x : !firrtl.uint<1> - // CHECK: firrtl.domain.define %D, %bar_ClockDomain : !firrtl.domain<@ClockDomain()> + // CHECK: firrtl.domain.define %D, %bar_foo_D : !firrtl.domain<@ClockDomain()> // CHECK: } firrtl.module @InferOutputDomain(out %D: !firrtl.domain<@ClockDomain()>, out %x: !firrtl.uint<1> domains [%D]) { %bar_x = firrtl.instance bar @Bar(out x : !firrtl.uint<1>)