Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
119 changes: 117 additions & 2 deletions src/contraction_utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,9 @@
#include <unordered_set>

#include "errors.h"
// NEW
#include "talshxx.hpp"
// NEW UP TO HERE

namespace qflex {

Expand Down Expand Up @@ -233,7 +236,12 @@ void ContractionData::ContractGrid(
std::string result_id = result_space(patch_max_size_[op.expand.id]);
Tensor& result = get_scratch(result_id);
try {
multiply(prev, next, result, get_scratch(kGeneralSpace).data());
// NEW REMOVE
//multiply(prev, next, result, get_scratch(kGeneralSpace).data());
// NEW REMOVE UP TO HERE
// NEW
multiply_with_talsh(prev, next, result);
// NEW UP TO HERE
} catch (const std::string& err_msg) {
throw ERROR_MSG("Failed to call multiply(). Error:\n\t[", err_msg,
"]");
Expand Down Expand Up @@ -318,7 +326,12 @@ void ContractionData::ContractGrid(
result_space(patch_max_size_[op.merge.target_id]);
Tensor& result = get_scratch(result_id);
try {
multiply(patch_1, patch_2, result, get_scratch(kGeneralSpace).data());
// NEW REMOVE
//multiply(patch_1, patch_2, result, get_scratch(kGeneralSpace).data());
// NEW REMOVE UP TO HERE
// NEW
multiply_with_talsh(patch_1, patch_2, result);
// NEW UP TO HERE
} catch (const std::string& err_msg) {
throw ERROR_MSG("Failed to call multiply(). Error:\n\t[", err_msg,
"]");
Expand Down Expand Up @@ -622,10 +635,112 @@ void ContractGrid(const std::list<ContractionOperation>& ordering,
active_patches[patch] = false;
}
try {
// NEW
std::size_t init_space(10000000000);
talsh::initialize(&init_space);
std::cout << "Initialized TALSH.\n";
// NEW UP TO HERE
data.ContractGrid(ordering, /*output_index = */ 0, active_patches);
// NEW
talsh::shutdown();
std::cout << "Shut down TALSH.\n";
// NEW UP TO HERE
} catch (const std::string& err_msg) {
throw ERROR_MSG("Failed to call ContractGrid(). Error:\n\t[", err_msg, "]");
}
}

// NEW

// Concatenate strings separated by commas in the result.
// Joins the single-letter codes of the given index names with commas,
// emitting them in reverse order (last index first). Returns the empty
// string when `strings` is empty.
std::string comma_concatenate_reversed(
    const std::vector<std::string>& strings,
    const std::unordered_map<std::string, std::string>& index_letter) {
  std::string joined;
  // Walk back-to-front; prepend a comma before every element but the first.
  for (auto it = strings.rbegin(); it != strings.rend(); ++it) {
    if (!joined.empty()) joined += ",";
    joined += index_letter.at(*it);
  }
  return joined;
}

// Contracts qFlex tensors A and B into C using the TALSH library.
// A "dry" call to multiply() first updates C's indices and dimensions
// without performing the numerics; the contraction itself is delegated
// to TALSH via D += L * R on the host device.
// NOTE(review): assumes talsh::initialize() has already been called
// (see ContractGrid) — TODO confirm.
// Throws a std::string error message (via ERROR_MSG) on aliasing,
// alphabet exhaustion, or TALSH failure.
void multiply_with_talsh(Tensor& A, Tensor& B, Tensor& C) {
  // TODO: Remove this and use _ALPHABET from tensor.cpp instead.
  static const std::vector<std::string> _ALPHABET(
      {"a", "b", "c", "d", "e", "f", "g", "h", "i", "j", "k", "l", "m",
       "n", "o", "p", "q", "r", "s", "t", "u", "v", "w", "x", "y", "z",
       "A", "B", "C", "D", "E", "F", "G", "H", "I", "J", "K", "L", "M",
       "N", "O", "P", "Q", "R", "S", "T", "U", "V", "W", "X", "Y", "Z"});
  if (A.data() == C.data()) {
    throw ERROR_MSG("A and C cannot be the same tensor: ",
                    C.tensor_to_string());
  }
  if (B.data() == C.data()) {
    throw ERROR_MSG("B and C cannot be the same tensor: ",
                    C.tensor_to_string());
  }

  // Collect the distinct index names appearing on A and B.
  std::unordered_set<std::string> unique_indices;
  for (Tensor* T : {&A, &B}) {
    for (const auto& index : T->get_indices()) {
      unique_indices.insert(index);
    }
  }
  // Guard against running past the 52-letter alphabet below.
  if (unique_indices.size() > _ALPHABET.size()) {
    throw ERROR_MSG("Number of distinct indices (", unique_indices.size(),
                    ") exceeds the available alphabet size (",
                    _ALPHABET.size(), ").");
  }
  // Assign each index a unique single-letter label for the TALSH pattern.
  std::unordered_map<std::string, std::string> index_letter;
  {
    std::size_t alphabet_pos = 0;
    for (const auto& index : unique_indices) {
      index_letter[index] = _ALPHABET[alphabet_pos];
      ++alphabet_pos;
    }
  }

  // Dry run: updates dimensions and indices on C without computing.
  multiply(A, B, C, nullptr, true);

  // Create the contraction string, reversing each tensor's indices so
  // that TALSH contracts properly, since it follows FORTRAN
  // (column-major) convention.
  std::string contraction_string = "D(";
  contraction_string +=
      comma_concatenate_reversed(C.get_indices(), index_letter);
  contraction_string += ")+=L(";
  contraction_string +=
      comma_concatenate_reversed(A.get_indices(), index_letter);
  contraction_string += ")*R(";
  contraction_string +=
      comma_concatenate_reversed(B.get_indices(), index_letter);
  contraction_string += ")";

  // TALSH signatures likewise take the dimensions in reversed order;
  // cast explicitly since talsh::Tensor expects int extents.
  auto reversed_dims = [](Tensor& T) {
    const auto& dims = T.get_dimensions();
    std::vector<int> reversed;
    reversed.reserve(dims.size());
    for (auto it = dims.rbegin(); it != dims.rend(); ++it) {
      reversed.push_back(static_cast<int>(*it));
    }
    return reversed;
  };
  talsh::Tensor D(reversed_dims(C), C.data());
  talsh::Tensor L(reversed_dims(A), A.data());
  talsh::Tensor R(reversed_dims(B), B.data());

  // D += L * R on the host; fail loudly instead of discarding the status
  // (the previous version assigned errc/done and never checked them).
  const int errc = D.contractAccumulate(nullptr, contraction_string, L, R,
                                        DEV_HOST, 0, s_type(1.0), false);
  if (errc != 0) {  // presumably TALSH_SUCCESS == 0 — TODO confirm
    throw ERROR_MSG("TALSH contractAccumulate() failed with error code: ",
                    errc);
  }
  if (!D.sync()) {
    throw ERROR_MSG("TALSH tensor sync() failed.");
  }
}
// NEW UP TO HERE

} // namespace qflex
4 changes: 4 additions & 0 deletions src/contraction_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -219,6 +219,10 @@ void ContractGrid(const std::list<ContractionOperation>& ordering,
std::vector<std::vector<Tensor>>* tensor_grid,
std::vector<std::complex<double>>* amplitudes);

// NEW
void multiply_with_talsh(Tensor& A, Tensor& B, Tensor& C);
// NEW UP TO HERE

} // namespace qflex

#endif // CONTRACTION_UTILS_
20 changes: 19 additions & 1 deletion src/tensor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -960,10 +960,22 @@ void _multiply_vv(const s_type* A_data, const s_type* B_data, s_type* C_data,
cblas_cdotu_sub(k, A_data, 1, B_data, 1, C_data);
}

void multiply(Tensor& A, Tensor& B, Tensor& C, s_type* scratch_copy) {
// NEW ERASES
//void multiply(Tensor& A, Tensor& B, Tensor& C, s_type* scratch_copy) {
// NEW ERASES UP TO HERE
// NEW
void multiply(Tensor& A, Tensor& B, Tensor& C, s_type* scratch_copy,
bool dry /*= false*/) {
// NEW UP TO HERE
// NEW
if (!dry) {
// NEW UP TO HERE
if (scratch_copy == nullptr) {
throw ERROR_MSG("Scratch copy must be non-null.");
}
// NEW
}
// NEW UP TO HERE

if (A.data() == C.data()) {
throw ERROR_MSG("A and C cannot be the same tensor: ",
Expand Down Expand Up @@ -1020,6 +1032,9 @@ void multiply(Tensor& A, Tensor& B, Tensor& C, s_type* scratch_copy) {

if (global::verbose > 1) t0 = std::chrono::high_resolution_clock::now();

// NEW
if (!dry) {
// NEW UP TO HERE
// Reorder.
std::vector<std::string> A_new_ordering =
_vector_union(left_indices, common_indices);
Expand Down Expand Up @@ -1096,6 +1111,9 @@ void multiply(Tensor& A, Tensor& B, Tensor& C, s_type* scratch_copy) {
std::cerr << "Time multiplying A*B: " << time_span.count() << "s\n";
t0 = std::chrono::high_resolution_clock::now();
}
// NEW
}
// NEW UP TO HERE

// Set indices and dimensions of C.
std::vector<std::string> C_indices =
Expand Down
8 changes: 7 additions & 1 deletion src/tensor.h
Original file line number Diff line number Diff line change
Expand Up @@ -410,7 +410,13 @@ void _multiply_vv(const s_type* A_data, const s_type* B_data, s_type* C_data,
* reordering. It has to allocate at least as much max(A.size(), B.size())
* memory.
*/
void multiply(Tensor& A, Tensor& B, Tensor& C, s_type* scratch_copy);
// NEW ERASES
//void multiply(Tensor& A, Tensor& B, Tensor& C, s_type* scratch_copy);
// NEW ERASES UP TO HERE
// NEW
void multiply(Tensor& A, Tensor& B, Tensor& C, s_type* scratch_copy,
bool dry = false);
// NEW UP TO HERE

/**
* Returns the size of the tensor product of A and B.
Expand Down