Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
132 changes: 106 additions & 26 deletions lib/Dialect/FIRRTL/Transforms/InferDomains.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -970,6 +970,13 @@ using PendingSolutions = DenseMap<VariableTerm *, unsigned>;
/// port index.
using PendingExports = llvm::MapVector<DomainValue, unsigned>;

/// A map from domain values to their origin (instance name + port name).
/// Used to generate better names for inferred domain ports.
using DomainOriginTable = DenseMap<DomainValue, std::string>;

/// A map from terms to their origin. Used to track origins through unification.
using TermOriginTable = DenseMap<Term *, std::string>;

namespace {
struct PendingUpdates {
PortInsertions insertions;
Expand All @@ -981,6 +988,8 @@ struct PendingUpdates {
/// If `var` is not solved, solve it by recording a pending input port at
/// the indicated insertion point.
static void ensureSolved(const DomainInfo &info, Namespace &ns,
const DomainOriginTable &valueOrigins,
const TermOriginTable &termOrigins,
DomainTypeID typeID, size_t ip, LocationAttr loc,
VariableTerm *var, PendingUpdates &pending) {
if (pending.solutions.contains(var))
Expand All @@ -990,7 +999,21 @@ static void ensureSolved(const DomainInfo &info, Namespace &ns,
auto domainDecl = info.getDomain(typeID);
auto domainName = domainDecl.getNameAttr();

auto portName = StringAttr::get(context, ns.newName(domainName.getValue()));
// Try to find an origin name for this variable. Check both the term origin
// table and the value origin table.
StringRef baseName = domainName.getValue();

// First, check if this term itself has an origin.
Term *root = find(var);
if (auto it = termOrigins.find(root); it != termOrigins.end()) {
baseName = it->second;
} else if (auto *val = dyn_cast<ValueTerm>(root)) {
// If the term is a value, check the value origin table.
if (auto it = valueOrigins.find(val->value); it != valueOrigins.end())
baseName = it->second;
}

auto portName = StringAttr::get(context, ns.newName(baseName));
auto portType = DomainType::getFromDomainOp(domainDecl);
auto portDirection = Direction::In;
auto portSym = StringAttr();
Expand All @@ -1011,6 +1034,7 @@ static void ensureSolved(const DomainInfo &info, Namespace &ns,
// an
/// output port.
static void ensureExported(const DomainInfo &info, Namespace &ns,
const DomainOriginTable &origins,
const ExportTable &exports, DomainTypeID typeID,
size_t ip, LocationAttr loc, ValueTerm *val,
PendingUpdates &pending) {
Expand All @@ -1025,7 +1049,12 @@ static void ensureExported(const DomainInfo &info, Namespace &ns,
auto domainDecl = info.getDomain(typeID);
auto domainName = domainDecl.getNameAttr();

auto portName = StringAttr::get(context, ns.newName(domainName.getValue()));
// Use origin name if available, otherwise fall back to domain name.
StringRef baseName = domainName.getValue();
if (auto it = origins.find(value); it != origins.end())
baseName = it->second;

auto portName = StringAttr::get(context, ns.newName(baseName));
auto portType = DomainType::getFromDomainOp(domainDecl);
auto portDirection = Direction::Out;
auto portSym = StringAttr();
Expand All @@ -1039,57 +1068,66 @@ static void ensureExported(const DomainInfo &info, Namespace &ns,
pending.insertions.push_back({ip, portInfo});
}

static void getUpdatesForDomainAssociationOfPort(const DomainInfo &info,
Namespace &ns,
PendingUpdates &pending,
DomainTypeID typeID, size_t ip,
LocationAttr loc, Term *term,
const ExportTable &exports) {
static void getUpdatesForDomainAssociationOfPort(
const DomainInfo &info, Namespace &ns,
const DomainOriginTable &valueOrigins, const TermOriginTable &termOrigins,
PendingUpdates &pending, DomainTypeID typeID, size_t ip, LocationAttr loc,
Term *term, const ExportTable &exports) {
if (auto *var = dyn_cast<VariableTerm>(term)) {
ensureSolved(info, ns, typeID, ip, loc, var, pending);
ensureSolved(info, ns, valueOrigins, termOrigins, typeID, ip, loc, var,
pending);
return;
}
if (auto *val = dyn_cast<ValueTerm>(term)) {
ensureExported(info, ns, exports, typeID, ip, loc, val, pending);
ensureExported(info, ns, valueOrigins, exports, typeID, ip, loc, val,
pending);
return;
}
llvm_unreachable("invalid domain association");
}

static void getUpdatesForDomainAssociationOfPort(
const DomainInfo &info, Namespace &ns, const ExportTable &exports,
size_t ip, LocationAttr loc, RowTerm *row, PendingUpdates &pending) {
const DomainInfo &info, Namespace &ns,
const DomainOriginTable &valueOrigins, const TermOriginTable &termOrigins,
const ExportTable &exports, size_t ip, LocationAttr loc, RowTerm *row,
PendingUpdates &pending) {
for (auto [index, term] : llvm::enumerate(row->elements))
getUpdatesForDomainAssociationOfPort(info, ns, pending, DomainTypeID{index},
ip, loc, find(term), exports);
getUpdatesForDomainAssociationOfPort(info, ns, valueOrigins, termOrigins,
pending, DomainTypeID{index}, ip, loc,
find(term), exports);
}

static void getUpdatesForModulePorts(const DomainInfo &info,
TermAllocator &allocator,
const ExportTable &exports,
DomainTable &table, Namespace &ns,
FModuleOp moduleOp,
PendingUpdates &pending) {
static void
getUpdatesForModulePorts(const DomainInfo &info, TermAllocator &allocator,
const ExportTable &exports, DomainTable &table,
Namespace &ns, const DomainOriginTable &valueOrigins,
const TermOriginTable &termOrigins, FModuleOp moduleOp,
PendingUpdates &pending) {
for (size_t i = 0, e = moduleOp.getNumPorts(); i < e; ++i) {
auto port = moduleOp.getArgument(i);
auto type = port.getType();
if (!isa<FIRRTLBaseType>(type))
continue;
getUpdatesForDomainAssociationOfPort(
info, ns, exports, i, moduleOp.getPortLocation(i),
info, ns, valueOrigins, termOrigins, exports, i,
moduleOp.getPortLocation(i),
getDomainAssociationAsRow(info, allocator, table, port), pending);
}
}

static void getUpdatesForModule(const DomainInfo &info,
TermAllocator &allocator,
const ExportTable &exports, DomainTable &table,
FModuleOp mod, PendingUpdates &pending) {
FModuleOp mod, PendingUpdates &pending,
const DomainOriginTable &valueOrigins,
const TermOriginTable &termOrigins) {
Namespace ns;
auto names = mod.getPortNamesAttr();
for (auto name : names.getAsRange<StringAttr>())
ns.add(name);
getUpdatesForModulePorts(info, allocator, exports, table, ns, mod, pending);

getUpdatesForModulePorts(info, allocator, exports, table, ns, valueOrigins,
termOrigins, mod, pending);
}

static void applyUpdatesToModule(const DomainInfo &info,
Expand Down Expand Up @@ -1282,8 +1320,7 @@ static LogicalResult updateInstance(const DomainInfo &info,
return success();
}

/// After updating the port domain associations, walk the body of the moduleOp
/// to fix up any child instance modules.
/// Update the module body by creating domain.define operations for instances.
static LogicalResult updateModuleBody(const DomainInfo &info,
TermAllocator &allocator,
DomainTable &table, FModuleOp moduleOp) {
Expand All @@ -1299,13 +1336,56 @@ static LogicalResult updateModuleBody(const DomainInfo &info,
return failure(result.wasInterrupted());
}

/// Build domain origin tables by walking instances. This is done in a separate
/// function so it can be called before getUpdatesForModule.
static void buildDomainOriginTables(const DomainTable &table, FModuleOp mod,
DomainOriginTable &valueOrigins,
TermOriginTable &termOrigins) {
mod.getBodyBlock()->walk([&](FInstanceLike op) {
auto instanceName = op.getInstanceName();
for (size_t i = 0, e = op->getNumResults(); i < e; ++i) {
auto port = dyn_cast<DomainValue>(op->getResult(i));
if (!port)
continue;

auto portName = op.getPortNameAttr(i);
std::string originName = (instanceName + "_" + portName.getValue()).str();

// Track this origin for the domain value (use first occurrence).
if (!valueOrigins.count(port))
valueOrigins[port] = originName;

// Also track the origin for the term associated with this port.
// This is crucial for input ports where the term might be a variable.
if (auto *term = table.getOptTermForDomain(port)) {
Term *root = find(term);
if (!termOrigins.count(root))
termOrigins[root] = originName;

// If the term is a ValueTerm, also track the value.
if (auto *val = dyn_cast<ValueTerm>(root)) {
if (!valueOrigins.count(val->value))
valueOrigins[val->value] = originName;
}
}
}
});
}

/// Write the domain associations recorded in the domain table back to the IR.
static LogicalResult updateModule(const DomainInfo &info,
TermAllocator &allocator, DomainTable &table,
ModuleUpdateTable &updates, FModuleOp op) {
// Build domain origin tables by walking instances. This will be used for
// generating better port names.
DomainOriginTable valueOrigins;
TermOriginTable termOrigins;
buildDomainOriginTables(table, op, valueOrigins, termOrigins);

auto exports = initializeExportTable(table, op);
PendingUpdates pending;
getUpdatesForModule(info, allocator, exports, table, op, pending);
getUpdatesForModule(info, allocator, exports, table, op, pending,
valueOrigins, termOrigins);
applyUpdatesToModule(info, allocator, exports, table, op, pending);

// Update the domain info for the moduleOp's ports.
Expand Down
18 changes: 9 additions & 9 deletions test/Dialect/FIRRTL/infer-domains-infer-all.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -168,16 +168,16 @@ firrtl.circuit "ExportDomain" {
)

firrtl.module @ExportDomain(
// CHECK: out %ClockDomain: !firrtl.domain<@ClockDomain()>
// CHECK: out %o: !firrtl.uint<1> domains [%ClockDomain]
// CHECK: out %foo_A: !firrtl.domain<@ClockDomain()>
// CHECK: out %o: !firrtl.uint<1> domains [%foo_A]
out %o: !firrtl.uint<1>
) {
%foo_A, %foo_o = firrtl.instance foo @Foo(
out A: !firrtl.domain<@ClockDomain()>,
out o: !firrtl.uint<1> domains [A]
)
firrtl.matchingconnect %o, %foo_o : !firrtl.uint<1>
// CHECK: firrtl.domain.define %ClockDomain, %foo_A : !firrtl.domain<@ClockDomain()>
// CHECK: firrtl.domain.define %foo_A, %foo_A_0 : !firrtl.domain<@ClockDomain()>
}
}

Expand Down Expand Up @@ -228,10 +228,10 @@ firrtl.circuit "InstanceUpdate" {

firrtl.module @Foo(in %i : !firrtl.uint<1>) {}

// CHECK: firrtl.module @InstanceUpdate(in %ClockDomain: !firrtl.domain<@ClockDomain()>, in %i: !firrtl.uint<1> domains [%ClockDomain]) {
// CHECK: %foo_ClockDomain, %foo_i = firrtl.instance foo @Foo(in ClockDomain: !firrtl.domain<@ClockDomain()>, in i: !firrtl.uint<1> domains [ClockDomain])
// CHECK: firrtl.module @InstanceUpdate(in %foo_ClockDomain: !firrtl.domain<@ClockDomain()>, in %i: !firrtl.uint<1> domains [%foo_ClockDomain]) {
// CHECK: %foo_ClockDomain_0, %foo_i = firrtl.instance foo @Foo(in ClockDomain: !firrtl.domain<@ClockDomain()>, in i: !firrtl.uint<1> domains [ClockDomain])
// CHECK: firrtl.connect %foo_i, %i : !firrtl.uint<1>
// CHECK: firrtl.domain.define %foo_ClockDomain, %ClockDomain : !firrtl.domain<@ClockDomain()>
// CHECK: firrtl.domain.define %foo_ClockDomain_0, %foo_ClockDomain : !firrtl.domain<@ClockDomain()>
// CHECK: }
firrtl.module @InstanceUpdate(in %i : !firrtl.uint<1>) {
%foo_i = firrtl.instance foo @Foo(in i: !firrtl.uint<1>)
Expand All @@ -252,10 +252,10 @@ firrtl.circuit "InstanceChoiceUpdate" {
firrtl.module @Bar(in %i : !firrtl.uint<1>) {}
firrtl.module @Baz(in %i : !firrtl.uint<1>) {}

// CHECK: firrtl.module @InstanceChoiceUpdate(in %ClockDomain: !firrtl.domain<@ClockDomain()>, in %i: !firrtl.uint<1> domains [%ClockDomain]) {
// CHECK: %inst_ClockDomain, %inst_i = firrtl.instance_choice inst @Foo alternatives @Option { @X -> @Bar, @Y -> @Baz } (in ClockDomain: !firrtl.domain<@ClockDomain()>, in i: !firrtl.uint<1> domains [ClockDomain])
// CHECK: firrtl.module @InstanceChoiceUpdate(in %inst_ClockDomain: !firrtl.domain<@ClockDomain()>, in %i: !firrtl.uint<1> domains [%inst_ClockDomain]) {
// CHECK: %inst_ClockDomain_0, %inst_i = firrtl.instance_choice inst @Foo alternatives @Option { @X -> @Bar, @Y -> @Baz } (in ClockDomain: !firrtl.domain<@ClockDomain()>, in i: !firrtl.uint<1> domains [ClockDomain])
// CHECK: firrtl.connect %inst_i, %i : !firrtl.uint<1>
// CHECK: firrtl.domain.define %inst_ClockDomain, %ClockDomain : !firrtl.domain<@ClockDomain()>
// CHECK: firrtl.domain.define %inst_ClockDomain_0, %inst_ClockDomain : !firrtl.domain<@ClockDomain()>
// CHECK: }
firrtl.module @InstanceChoiceUpdate(in %i : !firrtl.uint<1>) {
%inst_i = firrtl.instance_choice inst @Foo alternatives @Option { @X -> @Bar, @Y -> @Baz } (in i : !firrtl.uint<1>)
Expand Down
10 changes: 5 additions & 5 deletions test/Dialect/FIRRTL/infer-domains-infer.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,9 @@ firrtl.circuit "InferOutputDomain" {

firrtl.extmodule @Foo(out D: !firrtl.domain<@ClockDomain()>, out x: !firrtl.uint<1> domains [D])

// CHECK: firrtl.module private @Bar(out %ClockDomain: !firrtl.domain<@ClockDomain()>, out %x: !firrtl.uint<1> domains [%ClockDomain]) {
// CHECK: %foo_D, %foo_x = firrtl.instance foo @Foo(out D: !firrtl.domain<@ClockDomain()>, out x: !firrtl.uint<1> domains [D])
// CHECK: firrtl.domain.define %ClockDomain, %foo_D : !firrtl.domain<@ClockDomain()>
// CHECK: firrtl.module private @Bar(out %foo_D: !firrtl.domain<@ClockDomain()>, out %x: !firrtl.uint<1> domains [%foo_D]) {
// CHECK: %foo_D_0, %foo_x = firrtl.instance foo @Foo(out D: !firrtl.domain<@ClockDomain()>, out x: !firrtl.uint<1> domains [D])
// CHECK: firrtl.domain.define %foo_D, %foo_D_0 : !firrtl.domain<@ClockDomain()>
// CHECK: firrtl.matchingconnect %x, %foo_x : !firrtl.uint<1>
// CHECK: }
firrtl.module private @Bar(out %x : !firrtl.uint<1>) {
Expand All @@ -16,9 +16,9 @@ firrtl.circuit "InferOutputDomain" {
}

// CHECK: firrtl.module @InferOutputDomain(out %D: !firrtl.domain<@ClockDomain()>, out %x: !firrtl.uint<1> domains [%D]) {
// CHECK: %bar_ClockDomain, %bar_x = firrtl.instance bar @Bar(out ClockDomain: !firrtl.domain<@ClockDomain()>, out x: !firrtl.uint<1> domains [ClockDomain])
// CHECK: %bar_foo_D, %bar_x = firrtl.instance bar @Bar(out foo_D: !firrtl.domain<@ClockDomain()>, out x: !firrtl.uint<1> domains [foo_D])
// CHECK: firrtl.matchingconnect %x, %bar_x : !firrtl.uint<1>
// CHECK: firrtl.domain.define %D, %bar_ClockDomain : !firrtl.domain<@ClockDomain()>
// CHECK: firrtl.domain.define %D, %bar_foo_D : !firrtl.domain<@ClockDomain()>
// CHECK: }
firrtl.module @InferOutputDomain(out %D: !firrtl.domain<@ClockDomain()>, out %x: !firrtl.uint<1> domains [%D]) {
%bar_x = firrtl.instance bar @Bar(out x : !firrtl.uint<1>)
Expand Down
Loading