Skip to content

Commit

Permalink
[RTG][Elaboration] Use malloc instead of IR for virtual registers and…
Browse files Browse the repository at this point in the history
… labels
  • Loading branch information
maerhart committed Feb 13, 2025
1 parent b34c379 commit d748483
Showing 1 changed file with 67 additions and 78 deletions.
145 changes: 67 additions & 78 deletions lib/Dialect/RTG/Transforms/ElaborationPass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -89,56 +89,30 @@ struct BagStorage;
struct SequenceStorage;
struct RandomizedSequenceStorage;
struct SetStorage;
struct VirtualRegisterStorage;
struct UniqueLabelStorage;

/// Represents a unique virtual register.
struct VirtualRegister {
VirtualRegister(uint64_t id, ArrayAttr allowedRegs)
: id(id), allowedRegs(allowedRegs) {}

bool operator==(const VirtualRegister &other) const {
assert(
id != other.id ||
allowedRegs == other.allowedRegs &&
"instances with the same ID must have the same allowed registers");
return id == other.id;
}

// The ID of this virtual register.
uint64_t id;

// The list of fixed registers allowed to be selected for this virtual
// register.
ArrayAttr allowedRegs;
};

/// Simple wrapper around a 'StringAttr' such that we know to materialize it as
/// a label declaration instead of calling the builtin dialect constant
/// materializer.
struct LabelValue {
LabelValue(StringAttr name, uint64_t id = 0) : name(name), id(id) {}
LabelValue(StringAttr name) : name(name) {}

bool operator==(const LabelValue &other) const {
return name == other.name && id == other.id;
}
bool operator==(const LabelValue &other) const { return name == other.name; }

/// The label name. For unique labels, this is just the prefix.
/// The label name.
StringAttr name;

/// Standard label declarations always have id=0
uint64_t id;
};

/// The abstract base class for elaborated values.
using ElaboratorValue =
std::variant<TypedAttr, BagStorage *, bool, size_t, SequenceStorage *,
RandomizedSequenceStorage *, SetStorage *, VirtualRegister,
LabelValue>;

// NOLINTNEXTLINE(readability-identifier-naming)
llvm::hash_code hash_value(const VirtualRegister &val) {
return llvm::hash_value(val.id);
}
RandomizedSequenceStorage *, SetStorage *,
VirtualRegisterStorage *, UniqueLabelStorage *, LabelValue>;

// NOLINTNEXTLINE(readability-identifier-naming)
llvm::hash_code hash_value(const LabelValue &val) {
return llvm::hash_combine(val.id, val.name);
return llvm::hash_value(val.name);
}

// NOLINTNEXTLINE(readability-identifier-naming)
Expand All @@ -164,32 +138,16 @@ struct DenseMapInfo<bool> {

static bool isEqual(const bool &lhs, const bool &rhs) { return lhs == rhs; }
};

template <>
struct DenseMapInfo<VirtualRegister> {
static inline VirtualRegister getEmptyKey() {
return VirtualRegister(0, ArrayAttr());
}
static inline VirtualRegister getTombstoneKey() {
return VirtualRegister(~0, ArrayAttr());
}
static unsigned getHashValue(const VirtualRegister &val) {
return llvm::hash_combine(val.id, val.allowedRegs);
}

static bool isEqual(const VirtualRegister &lhs, const VirtualRegister &rhs) {
return lhs == rhs;
}
};

template <>
struct DenseMapInfo<LabelValue> {
static inline LabelValue getEmptyKey() { return LabelValue(StringAttr(), 0); }
static inline LabelValue getEmptyKey() {
return DenseMapInfo<StringAttr>::getEmptyKey();
}
static inline LabelValue getTombstoneKey() {
return LabelValue(StringAttr(), ~0);
return DenseMapInfo<StringAttr>::getTombstoneKey();
}
static unsigned getHashValue(const LabelValue &val) {
return llvm::hash_combine(val.name, val.id);
return hash_value(val);
}

static bool isEqual(const LabelValue &lhs, const LabelValue &rhs) {
Expand Down Expand Up @@ -351,6 +309,28 @@ struct RandomizedSequenceStorage {
const SequenceStorage *sequence;
};

/// Represents a unique virtual register.
struct VirtualRegisterStorage {
VirtualRegisterStorage(ArrayAttr allowedRegs) : allowedRegs(allowedRegs) {}

// NOTE: we don't need an 'isEqual' function and 'hashcode' here because
// VirtualRegisters are never internalized.

// The list of fixed registers allowed to be selected for this virtual
// register.
const ArrayAttr allowedRegs;
};

struct UniqueLabelStorage {
UniqueLabelStorage(StringAttr name) : name(name) {}

// NOTE: we don't need an 'isEqual' function and 'hashcode' here because
// VirtualRegisters are never internalized.

/// The label name. For unique labels, this is just the prefix.
const StringAttr name;
};

/// An 'Internalizer' object internalizes storages and takes ownership of them.
/// When the initializer object is destroyed, all owned storages are also
/// deallocated and thus must not be accessed anymore.
Expand All @@ -375,6 +355,12 @@ class Internalizer {
return storagePtr;
}

template <typename StorageTy, typename... Args>
StorageTy *create(Args &&...args) {
return new (allocator.Allocate<StorageTy>())
StorageTy(std::forward<Args>(args)...);
}

private:
template <typename StorageTy>
DenseSet<HashedStorage<StorageTy>, StorageKeyInfo<StorageTy>> &
Expand Down Expand Up @@ -459,12 +445,16 @@ static void print(SetStorage *val, llvm::raw_ostream &os) {
os << "} at " << val << ">";
}

static void print(const VirtualRegister &val, llvm::raw_ostream &os) {
os << "<virtual-register " << val.id << " " << val.allowedRegs << ">";
static void print(const VirtualRegisterStorage *val, llvm::raw_ostream &os) {
os << "<virtual-register " << val << " " << val->allowedRegs << ">";
}

static void print(const UniqueLabelStorage *val, llvm::raw_ostream &os) {
os << "<unique-label " << val << " " << val->name << ">";
}

static void print(const LabelValue &val, llvm::raw_ostream &os) {
os << "<label " << val.id << " " << val.name << ">";
os << "<label " << val.name << ">";
}

static llvm::raw_ostream &operator<<(llvm::raw_ostream &os,
Expand Down Expand Up @@ -690,24 +680,26 @@ class Materializer {
return builder.create<RandomizeSequenceOp>(loc, seq);
}

Value visit(const VirtualRegister &val, Location loc,
Value visit(VirtualRegisterStorage *val, Location loc,
std::queue<RandomizedSequenceStorage *> &elabRequests,
function_ref<InFlightDiagnostic()> emitError) {
auto res = builder.create<VirtualRegisterOp>(loc, val.allowedRegs);
Value res = builder.create<VirtualRegisterOp>(loc, val->allowedRegs);
materializedValues[val] = res;
return res;
}

Value visit(const LabelValue &val, Location loc,
Value visit(UniqueLabelStorage *val, Location loc,
std::queue<RandomizedSequenceStorage *> &elabRequests,
function_ref<InFlightDiagnostic()> emitError) {
if (val.id == 0) {
auto res = builder.create<LabelDeclOp>(loc, val.name, ValueRange());
materializedValues[val] = res;
return res;
}
Value res = builder.create<LabelUniqueDeclOp>(loc, val->name, ValueRange());
materializedValues[val] = res;
return res;
}

auto res = builder.create<LabelUniqueDeclOp>(loc, val.name, ValueRange());
Value visit(const LabelValue &val, Location loc,
std::queue<RandomizedSequenceStorage *> &elabRequests,
function_ref<InFlightDiagnostic()> emitError) {
Value res = builder.create<LabelDeclOp>(loc, val.name, ValueRange());
materializedValues[val] = res;
return res;
}
Expand Down Expand Up @@ -749,9 +741,6 @@ struct ElaboratorSharedState {
/// The worklist used to keep track of the test and sequence operations to
/// make sure they are processed top-down (BFS traversal).
std::queue<RandomizedSequenceStorage *> worklist;

uint64_t virtualRegisterID = 0;
uint64_t uniqueLabelID = 1;
};

/// A collection of state per RTG test.
Expand Down Expand Up @@ -1023,8 +1012,9 @@ class Elaborator : public RTGOpVisitor<Elaborator, FailureOr<DeletionKind>> {
}

FailureOr<DeletionKind> visitOp(VirtualRegisterOp op) {
state[op.getResult()] = VirtualRegister(sharedState.virtualRegisterID++,
op.getAllowedRegsAttr());
state[op.getResult()] =
sharedState.internalizer.create<VirtualRegisterStorage>(
op.getAllowedRegsAttr());
return DeletionKind::Delete;
}

Expand Down Expand Up @@ -1055,9 +1045,8 @@ class Elaborator : public RTGOpVisitor<Elaborator, FailureOr<DeletionKind>> {
}

FailureOr<DeletionKind> visitOp(LabelUniqueDeclOp op) {
state[op.getLabel()] = LabelValue(
substituteFormatString(op.getFormatStringAttr(), op.getArgs()),
sharedState.uniqueLabelID++);
state[op.getLabel()] = sharedState.internalizer.create<UniqueLabelStorage>(
substituteFormatString(op.getFormatStringAttr(), op.getArgs()));
return DeletionKind::Delete;
}

Expand Down

0 comments on commit d748483

Please # to comment.