From 2e872614fc950c2be7942e7f224be7cb8280e722 Mon Sep 17 00:00:00 2001 From: Schuyler Eldridge Date: Wed, 25 Mar 2026 16:26:00 -0400 Subject: [PATCH] [FIRRTL] Improve domain port naming in InferDomains pass Inferred domain ports are now named based on their origin (instance name and port name) rather than just the domain type name. This makes it easier to trace where domains come from in the design. For example, a domain port that originates from instance 'bar' port 'A' is now named 'bar_A' instead of 'ClockDomain' or 'ClockDomain_0'. This change tracks domain origins through the unification process and uses them when generating port names, falling back to the domain type name when no origin is available. AI-assisted-by: Augment (Claude Sonnet 4.5) Signed-off-by: Schuyler Eldridge --- .../FIRRTL/Transforms/InferDomains.cpp | 132 ++++++++++++++---- .../FIRRTL/infer-domains-infer-all.mlir | 18 +-- test/Dialect/FIRRTL/infer-domains-infer.mlir | 10 +- 3 files changed, 120 insertions(+), 40 deletions(-) diff --git a/lib/Dialect/FIRRTL/Transforms/InferDomains.cpp b/lib/Dialect/FIRRTL/Transforms/InferDomains.cpp index 830f70b07f5b..83ce76c24856 100644 --- a/lib/Dialect/FIRRTL/Transforms/InferDomains.cpp +++ b/lib/Dialect/FIRRTL/Transforms/InferDomains.cpp @@ -970,6 +970,13 @@ using PendingSolutions = DenseMap; /// port index. using PendingExports = llvm::MapVector; +/// A map from domain values to their origin (instance name + port name). +/// Used to generate better names for inferred domain ports. +using DomainOriginTable = DenseMap; + +/// A map from terms to their origin. Used to track origins through unification. +using TermOriginTable = DenseMap; + namespace { struct PendingUpdates { PortInsertions insertions; @@ -981,6 +988,8 @@ struct PendingUpdates { /// If `var` is not solved, solve it by recording a pending input port at /// the indicated insertion point. static void ensureSolved(const DomainInfo &info, Namespace &ns, + const DomainOriginTable &valueOrigins, + const TermOriginTable &termOrigins, DomainTypeID typeID, size_t ip, LocationAttr loc, VariableTerm *var, PendingUpdates &pending) { if (pending.solutions.contains(var)) @@ -990,7 +999,21 @@ static void ensureSolved(const DomainInfo &info, Namespace &ns, auto domainDecl = info.getDomain(typeID); auto domainName = domainDecl.getNameAttr(); - auto portName = StringAttr::get(context, ns.newName(domainName.getValue())); + // Try to find an origin name for this variable. Check both the term origin + // table and the value origin table. + StringRef baseName = domainName.getValue(); + + // First, check if this term itself has an origin. + Term *root = find(var); + if (auto it = termOrigins.find(root); it != termOrigins.end()) { + baseName = it->second; + } else if (auto *val = dyn_cast(root)) { + // If the term is a value, check the value origin table. + if (auto it = valueOrigins.find(val->value); it != valueOrigins.end()) + baseName = it->second; + } + + auto portName = StringAttr::get(context, ns.newName(baseName)); auto portType = DomainType::getFromDomainOp(domainDecl); auto portDirection = Direction::In; auto portSym = StringAttr(); @@ -1011,6 +1034,7 @@ static void ensureSolved(const DomainInfo &info, Namespace &ns, // an /// output port. static void ensureExported(const DomainInfo &info, Namespace &ns, + const DomainOriginTable &origins, const ExportTable &exports, DomainTypeID typeID, size_t ip, LocationAttr loc, ValueTerm *val, PendingUpdates &pending) { @@ -1025,7 +1049,12 @@ static void ensureExported(const DomainInfo &info, Namespace &ns, auto domainDecl = info.getDomain(typeID); auto domainName = domainDecl.getNameAttr(); - auto portName = StringAttr::get(context, ns.newName(domainName.getValue())); + // Use origin name if available, otherwise fall back to domain name. + StringRef baseName = domainName.getValue(); + if (auto it = origins.find(value); it != origins.end()) + baseName = it->second; + + auto portName = StringAttr::get(context, ns.newName(baseName)); auto portType = DomainType::getFromDomainOp(domainDecl); auto portDirection = Direction::Out; auto portSym = StringAttr(); @@ -1039,44 +1068,49 @@ static void ensureExported(const DomainInfo &info, Namespace &ns, pending.insertions.push_back({ip, portInfo}); } -static void getUpdatesForDomainAssociationOfPort(const DomainInfo &info, - Namespace &ns, - PendingUpdates &pending, - DomainTypeID typeID, size_t ip, - LocationAttr loc, Term *term, - const ExportTable &exports) { +static void getUpdatesForDomainAssociationOfPort( + const DomainInfo &info, Namespace &ns, + const DomainOriginTable &valueOrigins, const TermOriginTable &termOrigins, + PendingUpdates &pending, DomainTypeID typeID, size_t ip, LocationAttr loc, + Term *term, const ExportTable &exports) { if (auto *var = dyn_cast(term)) { - ensureSolved(info, ns, typeID, ip, loc, var, pending); + ensureSolved(info, ns, valueOrigins, termOrigins, typeID, ip, loc, var, + pending); return; } if (auto *val = dyn_cast(term)) { - ensureExported(info, ns, exports, typeID, ip, loc, val, pending); + ensureExported(info, ns, valueOrigins, exports, typeID, ip, loc, val, + pending); return; } llvm_unreachable("invalid domain association"); } static void getUpdatesForDomainAssociationOfPort( - const DomainInfo &info, Namespace &ns, const ExportTable &exports, - size_t ip, LocationAttr loc, RowTerm *row, PendingUpdates &pending) { + const DomainInfo &info, Namespace &ns, + const DomainOriginTable &valueOrigins, const TermOriginTable &termOrigins, + const ExportTable &exports, size_t ip, LocationAttr loc, RowTerm *row, + PendingUpdates &pending) { for (auto [index, term] : llvm::enumerate(row->elements)) - getUpdatesForDomainAssociationOfPort(info, ns, pending, DomainTypeID{index}, - ip, loc, find(term), exports); + getUpdatesForDomainAssociationOfPort(info, ns, valueOrigins, termOrigins, + pending, DomainTypeID{index}, ip, loc, + find(term), exports); } -static void getUpdatesForModulePorts(const DomainInfo &info, - TermAllocator &allocator, - const ExportTable &exports, - DomainTable &table, Namespace &ns, - FModuleOp moduleOp, - PendingUpdates &pending) { +static void +getUpdatesForModulePorts(const DomainInfo &info, TermAllocator &allocator, + const ExportTable &exports, DomainTable &table, + Namespace &ns, const DomainOriginTable &valueOrigins, + const TermOriginTable &termOrigins, FModuleOp moduleOp, + PendingUpdates &pending) { for (size_t i = 0, e = moduleOp.getNumPorts(); i < e; ++i) { auto port = moduleOp.getArgument(i); auto type = port.getType(); if (!isa(type)) continue; getUpdatesForDomainAssociationOfPort( - info, ns, exports, i, moduleOp.getPortLocation(i), + info, ns, valueOrigins, termOrigins, exports, i, + moduleOp.getPortLocation(i), getDomainAssociationAsRow(info, allocator, table, port), pending); } } @@ -1084,12 +1118,16 @@ static void getUpdatesForModulePorts(const DomainInfo &info, static void getUpdatesForModule(const DomainInfo &info, TermAllocator &allocator, const ExportTable &exports, DomainTable &table, - FModuleOp mod, PendingUpdates &pending) { + FModuleOp mod, PendingUpdates &pending, + const DomainOriginTable &valueOrigins, + const TermOriginTable &termOrigins) { Namespace ns; auto names = mod.getPortNamesAttr(); for (auto name : names.getAsRange()) ns.add(name); - getUpdatesForModulePorts(info, allocator, exports, table, ns, mod, pending); + + getUpdatesForModulePorts(info, allocator, exports, table, ns, valueOrigins, + termOrigins, mod, pending); } static void applyUpdatesToModule(const DomainInfo &info, @@ -1282,8 +1320,7 @@ static LogicalResult updateInstance(const DomainInfo &info, return success(); } -/// After updating the port domain associations, walk the body of the moduleOp -/// to fix up any child instance modules. +/// Update the module body by creating domain.define operations for instances. static LogicalResult updateModuleBody(const DomainInfo &info, TermAllocator &allocator, DomainTable &table, FModuleOp moduleOp) { @@ -1299,13 +1336,56 @@ static LogicalResult updateModuleBody(const DomainInfo &info, return failure(result.wasInterrupted()); } +/// Build domain origin tables by walking instances. This is done in a separate +/// function so it can be called before getUpdatesForModule. +static void buildDomainOriginTables(const DomainTable &table, FModuleOp mod, + DomainOriginTable &valueOrigins, + TermOriginTable &termOrigins) { + mod.getBodyBlock()->walk([&](FInstanceLike op) { + auto instanceName = op.getInstanceName(); + for (size_t i = 0, e = op->getNumResults(); i < e; ++i) { + auto port = dyn_cast(op->getResult(i)); + if (!port) + continue; + + auto portName = op.getPortNameAttr(i); + std::string originName = (instanceName + "_" + portName.getValue()).str(); + + // Track this origin for the domain value (use first occurrence). + if (!valueOrigins.count(port)) + valueOrigins[port] = originName; + + // Also track the origin for the term associated with this port. + // This is crucial for input ports where the term might be a variable. + if (auto *term = table.getOptTermForDomain(port)) { + Term *root = find(term); + if (!termOrigins.count(root)) + termOrigins[root] = originName; + + // If the term is a ValueTerm, also track the value. + if (auto *val = dyn_cast(root)) { + if (!valueOrigins.count(val->value)) + valueOrigins[val->value] = originName; + } + } + } + }); +} + /// Write the domain associations recorded in the domain table back to the IR. static LogicalResult updateModule(const DomainInfo &info, TermAllocator &allocator, DomainTable &table, ModuleUpdateTable &updates, FModuleOp op) { + // Build domain origin tables by walking instances. This will be used for + // generating better port names. + DomainOriginTable valueOrigins; + TermOriginTable termOrigins; + buildDomainOriginTables(table, op, valueOrigins, termOrigins); + auto exports = initializeExportTable(table, op); PendingUpdates pending; - getUpdatesForModule(info, allocator, exports, table, op, pending); + getUpdatesForModule(info, allocator, exports, table, op, pending, + valueOrigins, termOrigins); applyUpdatesToModule(info, allocator, exports, table, op, pending); // Update the domain info for the moduleOp's ports. diff --git a/test/Dialect/FIRRTL/infer-domains-infer-all.mlir b/test/Dialect/FIRRTL/infer-domains-infer-all.mlir index 9bf42e716902..a56ac0304ec3 100644 --- a/test/Dialect/FIRRTL/infer-domains-infer-all.mlir +++ b/test/Dialect/FIRRTL/infer-domains-infer-all.mlir @@ -168,8 +168,8 @@ firrtl.circuit "ExportDomain" { ) firrtl.module @ExportDomain( - // CHECK: out %ClockDomain: !firrtl.domain<@ClockDomain()> - // CHECK: out %o: !firrtl.uint<1> domains [%ClockDomain] + // CHECK: out %foo_A: !firrtl.domain<@ClockDomain()> + // CHECK: out %o: !firrtl.uint<1> domains [%foo_A] out %o: !firrtl.uint<1> ) { %foo_A, %foo_o = firrtl.instance foo @Foo( @@ -177,7 +177,7 @@ firrtl.circuit "ExportDomain" { out o: !firrtl.uint<1> domains [A] ) firrtl.matchingconnect %o, %foo_o : !firrtl.uint<1> - // CHECK: firrtl.domain.define %ClockDomain, %foo_A : !firrtl.domain<@ClockDomain()> + // CHECK: firrtl.domain.define %foo_A, %foo_A_0 : !firrtl.domain<@ClockDomain()> } } @@ -228,10 +228,10 @@ firrtl.circuit "InstanceUpdate" { firrtl.module @Foo(in %i : !firrtl.uint<1>) {} - // CHECK: firrtl.module @InstanceUpdate(in %ClockDomain: !firrtl.domain<@ClockDomain()>, in %i: !firrtl.uint<1> domains [%ClockDomain]) { - // CHECK: %foo_ClockDomain, %foo_i = firrtl.instance foo @Foo(in ClockDomain: !firrtl.domain<@ClockDomain()>, in i: !firrtl.uint<1> domains [ClockDomain]) + // CHECK: firrtl.module @InstanceUpdate(in %foo_ClockDomain: !firrtl.domain<@ClockDomain()>, in %i: !firrtl.uint<1> domains [%foo_ClockDomain]) { + // CHECK: %foo_ClockDomain_0, %foo_i = firrtl.instance foo @Foo(in ClockDomain: !firrtl.domain<@ClockDomain()>, in i: !firrtl.uint<1> domains [ClockDomain]) // CHECK: firrtl.connect %foo_i, %i : !firrtl.uint<1> - // CHECK: firrtl.domain.define %foo_ClockDomain, %ClockDomain : !firrtl.domain<@ClockDomain()> + // CHECK: firrtl.domain.define %foo_ClockDomain_0, %foo_ClockDomain : !firrtl.domain<@ClockDomain()> // CHECK: } firrtl.module @InstanceUpdate(in %i : !firrtl.uint<1>) { %foo_i = firrtl.instance foo @Foo(in i: !firrtl.uint<1>) @@ -252,10 +252,10 @@ firrtl.circuit "InstanceChoiceUpdate" { firrtl.module @Bar(in %i : !firrtl.uint<1>) {} firrtl.module @Baz(in %i : !firrtl.uint<1>) {} - // CHECK: firrtl.module @InstanceChoiceUpdate(in %ClockDomain: !firrtl.domain<@ClockDomain()>, in %i: !firrtl.uint<1> domains [%ClockDomain]) { - // CHECK: %inst_ClockDomain, %inst_i = firrtl.instance_choice inst @Foo alternatives @Option { @X -> @Bar, @Y -> @Baz } (in ClockDomain: !firrtl.domain<@ClockDomain()>, in i: !firrtl.uint<1> domains [ClockDomain]) + // CHECK: firrtl.module @InstanceChoiceUpdate(in %inst_ClockDomain: !firrtl.domain<@ClockDomain()>, in %i: !firrtl.uint<1> domains [%inst_ClockDomain]) { + // CHECK: %inst_ClockDomain_0, %inst_i = firrtl.instance_choice inst @Foo alternatives @Option { @X -> @Bar, @Y -> @Baz } (in ClockDomain: !firrtl.domain<@ClockDomain()>, in i: !firrtl.uint<1> domains [ClockDomain]) // CHECK: firrtl.connect %inst_i, %i : !firrtl.uint<1> - // CHECK: firrtl.domain.define %inst_ClockDomain, %ClockDomain : !firrtl.domain<@ClockDomain()> + // CHECK: firrtl.domain.define %inst_ClockDomain_0, %inst_ClockDomain : !firrtl.domain<@ClockDomain()> // CHECK: } firrtl.module @InstanceChoiceUpdate(in %i : !firrtl.uint<1>) { %inst_i = firrtl.instance_choice inst @Foo alternatives @Option { @X -> @Bar, @Y -> @Baz } (in i : !firrtl.uint<1>) diff --git a/test/Dialect/FIRRTL/infer-domains-infer.mlir b/test/Dialect/FIRRTL/infer-domains-infer.mlir index f97b33abd0f3..271d9fd64dbd 100644 --- a/test/Dialect/FIRRTL/infer-domains-infer.mlir +++ b/test/Dialect/FIRRTL/infer-domains-infer.mlir @@ -5,9 +5,9 @@ firrtl.circuit "InferOutputDomain" { firrtl.extmodule @Foo(out D: !firrtl.domain<@ClockDomain()>, out x: !firrtl.uint<1> domains [D]) - // CHECK: firrtl.module private @Bar(out %ClockDomain: !firrtl.domain<@ClockDomain()>, out %x: !firrtl.uint<1> domains [%ClockDomain]) { - // CHECK: %foo_D, %foo_x = firrtl.instance foo @Foo(out D: !firrtl.domain<@ClockDomain()>, out x: !firrtl.uint<1> domains [D]) - // CHECK: firrtl.domain.define %ClockDomain, %foo_D : !firrtl.domain<@ClockDomain()> + // CHECK: firrtl.module private @Bar(out %foo_D: !firrtl.domain<@ClockDomain()>, out %x: !firrtl.uint<1> domains [%foo_D]) { + // CHECK: %foo_D_0, %foo_x = firrtl.instance foo @Foo(out D: !firrtl.domain<@ClockDomain()>, out x: !firrtl.uint<1> domains [D]) + // CHECK: firrtl.domain.define %foo_D, %foo_D_0 : !firrtl.domain<@ClockDomain()> // CHECK: firrtl.matchingconnect %x, %foo_x : !firrtl.uint<1> // CHECK: } firrtl.module private @Bar(out %x : !firrtl.uint<1>) { @@ -16,9 +16,9 @@ firrtl.circuit "InferOutputDomain" { } // CHECK: firrtl.module @InferOutputDomain(out %D: !firrtl.domain<@ClockDomain()>, out %x: !firrtl.uint<1> domains [%D]) { - // CHECK: %bar_ClockDomain, %bar_x = firrtl.instance bar @Bar(out ClockDomain: !firrtl.domain<@ClockDomain()>, out x: !firrtl.uint<1> domains [ClockDomain]) + // CHECK: %bar_foo_D, %bar_x = firrtl.instance bar @Bar(out foo_D: !firrtl.domain<@ClockDomain()>, out x: !firrtl.uint<1> domains [foo_D]) // CHECK: firrtl.matchingconnect %x, %bar_x : !firrtl.uint<1> - // CHECK: firrtl.domain.define %D, %bar_ClockDomain : !firrtl.domain<@ClockDomain()> + // CHECK: firrtl.domain.define %D, %bar_foo_D : !firrtl.domain<@ClockDomain()> // CHECK: } firrtl.module @InferOutputDomain(out %D: !firrtl.domain<@ClockDomain()>, out %x: !firrtl.uint<1> domains [%D]) { %bar_x = firrtl.instance bar @Bar(out x : !firrtl.uint<1>)