Skip to content

Commit b59e1b3

Browse files
Use RAII principles for postgres
1 parent e07be25 commit b59e1b3

11 files changed

Lines changed: 168 additions & 95 deletions

File tree

include/sqlgen/postgres/Connection.hpp

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -15,28 +15,32 @@
1515
#include "../dynamic/Column.hpp"
1616
#include "../dynamic/Statement.hpp"
1717
#include "../dynamic/Write.hpp"
18+
#include "../internal/iterator_t.hpp"
1819
#include "../internal/to_container.hpp"
1920
#include "../internal/write_or_insert.hpp"
2021
#include "../is_connection.hpp"
2122
#include "../sqlgen_api.hpp"
2223
#include "../transpilation/value_t.hpp"
2324
#include "Credentials.hpp"
2425
#include "Iterator.hpp"
26+
#include "PostgresV2Connection.hpp"
2527
#include "exec.hpp"
2628
#include "to_sql.hpp"
2729

2830
namespace sqlgen::postgres {
2931

3032
class SQLGEN_API Connection {
31-
using ConnPtr = Ref<PGconn>;
33+
using Conn = PostgresV2Connection;
3234

3335
public:
36+
Connection(const Conn& _conn);
37+
3438
Connection(const Credentials& _credentials);
3539

3640
static rfl::Result<Ref<Connection>> make(
3741
const Credentials& _credentials) noexcept;
3842

39-
~Connection();
43+
~Connection() = default;
4044

4145
Result<Nothing> begin_transaction() noexcept;
4246

@@ -82,8 +86,6 @@ class SQLGEN_API Connection {
8286
const std::vector<std::vector<std::optional<std::string>>>&
8387
_data) noexcept;
8488

85-
static ConnPtr make_conn(const std::string& _conn_str);
86-
8789
Result<Ref<Iterator>> read_impl(const dynamic::SelectFrom& _query);
8890

8991
std::string to_buffer(
@@ -93,9 +95,7 @@ class SQLGEN_API Connection {
9395
const std::vector<std::vector<std::optional<std::string>>>& _data);
9496

9597
private:
96-
ConnPtr conn_;
97-
98-
Credentials credentials_;
98+
Conn conn_;
9999
};
100100

101101
static_assert(is_connection<Connection>,

include/sqlgen/postgres/Iterator.hpp

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,14 +11,16 @@
1111
#include "../Ref.hpp"
1212
#include "../Result.hpp"
1313
#include "../sqlgen_api.hpp"
14+
#include "PostgresV2Connection.hpp"
15+
#include "PostgresV2Result.hpp"
1416

1517
namespace sqlgen::postgres {
1618

1719
class SQLGEN_API Iterator {
18-
using ConnPtr = Ref<PGconn>;
20+
using Conn = PostgresV2Connection;
1921

2022
public:
21-
Iterator(const std::string& _sql, const ConnPtr& _conn);
23+
Iterator(const std::string& _sql, const Conn& _conn);
2224

2325
Iterator(const Iterator& _other) = delete;
2426

@@ -54,7 +56,7 @@ class SQLGEN_API Iterator {
5456

5557
/// The underlying postgres connection. We have this in here to prevent its
5658
/// destruction for the lifetime of the iterator.
57-
ConnPtr conn_;
59+
Conn conn_;
5860

5961
/// Whether the end is reached.
6062
bool end_;
Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
#ifndef SQLGEN_POSTGRES_POSTGRESV2CONNECTION_HPP_
2+
#define SQLGEN_POSTGRES_POSTGRESV2CONNECTION_HPP_
3+
4+
#include <libpq-fe.h>
5+
6+
#include <memory>
7+
#include <rfl.hpp>
8+
#include <stdexcept>
9+
10+
#include "../Ref.hpp"
11+
#include "../Result.hpp"
12+
#include "../sqlgen_api.hpp"
13+
14+
namespace sqlgen::postgres {
15+
16+
class SQLGEN_API PostgresV2Connection {
17+
public:
18+
PostgresV2Connection(PGconn* _ptr)
19+
: ptr_(Ref<PGconn>::make(std::shared_ptr<PGconn>(_ptr, &PQfinish))
20+
.value()) {}
21+
22+
~PostgresV2Connection() = default;
23+
24+
static rfl::Result<PostgresV2Connection> make(
25+
const std::string& _conn_str) noexcept;
26+
27+
PGconn* ptr() const { return ptr_.get(); }
28+
29+
private:
30+
Ref<PGconn> ptr_;
31+
};
32+
33+
} // namespace sqlgen::postgres
34+
35+
#endif
Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
#ifndef SQLGEN_POSTGRES_POSTGRESV2RESULT_HPP_
2+
#define SQLGEN_POSTGRES_POSTGRESV2RESULT_HPP_
3+
4+
#include <libpq-fe.h>
5+
6+
#include <memory>
7+
#include <rfl.hpp>
8+
#include <stdexcept>
9+
10+
#include "../Ref.hpp"
11+
#include "../Result.hpp"
12+
#include "../sqlgen_api.hpp"
13+
#include "PostgresV2Connection.hpp"
14+
15+
namespace sqlgen::postgres {
16+
17+
class SQLGEN_API PostgresV2Result {
18+
public:
19+
PostgresV2Result(PGresult* _ptr)
20+
: ptr_(Ref<PGresult>::make(std::shared_ptr<PGresult>(_ptr, &PQclear))
21+
.value()) {}
22+
23+
~PostgresV2Result() = default;
24+
25+
static rfl::Result<PostgresV2Result> make(
26+
const std::string& _query, const PostgresV2Connection& _conn) noexcept;
27+
28+
PGresult* ptr() const { return ptr_.get(); }
29+
30+
private:
31+
Ref<PGresult> ptr_;
32+
};
33+
34+
} // namespace sqlgen::postgres
35+
36+
#endif

include/sqlgen/postgres/exec.hpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,11 +9,13 @@
99
#include "../Ref.hpp"
1010
#include "../Result.hpp"
1111
#include "../sqlgen_api.hpp"
12+
#include "PostgresV2Connection.hpp"
13+
#include "PostgresV2Result.hpp"
1214

1315
namespace sqlgen::postgres {
1416

15-
Result<Ref<PGresult>> SQLGEN_API exec(const Ref<PGconn>& _conn,
16-
const std::string& _sql) noexcept;
17+
Result<PostgresV2Result> SQLGEN_API exec(const PostgresV2Connection& _conn,
18+
const std::string& _sql) noexcept;
1719

1820
} // namespace sqlgen::postgres
1921

src/sqlgen/postgres/Connection.cpp

Lines changed: 34 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -9,13 +9,14 @@
99
#include "sqlgen/internal/random.hpp"
1010
#include "sqlgen/internal/strings/strings.hpp"
1111
#include "sqlgen/postgres/Iterator.hpp"
12+
#include "sqlgen/postgres/PostgresV2Result.hpp"
1213

1314
namespace sqlgen::postgres {
1415

15-
Connection::Connection(const Credentials& _credentials)
16-
: conn_(make_conn(_credentials.to_str())), credentials_(_credentials) {}
16+
Connection::Connection(const Conn& _conn) : conn_(_conn) {}
1717

18-
Connection::~Connection() = default;
18+
Connection::Connection(const Credentials& _credentials)
19+
: conn_(PostgresV2Connection::make(_credentials.to_str()).value()) {}
1920

2021
Result<Nothing> Connection::begin_transaction() noexcept {
2122
return execute("BEGIN TRANSACTION;");
@@ -24,20 +25,19 @@ Result<Nothing> Connection::begin_transaction() noexcept {
2425
Result<Nothing> Connection::commit() noexcept { return execute("COMMIT;"); }
2526

2627
Result<Nothing> Connection::execute(const std::string& _sql) noexcept {
27-
return exec(conn_, _sql).transform([](auto&&) { return Nothing{}; });
28+
return PostgresV2Result::make(_sql, conn_).transform([](auto&&) {
29+
return Nothing{};
30+
});
2831
}
2932

3033
Result<Nothing> Connection::end_write() {
31-
if (PQputCopyEnd(conn_.get(), NULL) == -1) {
32-
return error(PQerrorMessage(conn_.get()));
34+
if (PQputCopyEnd(conn_.ptr(), NULL) == -1) {
35+
return error(PQerrorMessage(conn_.ptr()));
3336
}
34-
const auto res = PQgetResult(conn_.get());
35-
if (PQresultStatus(res) != PGRES_COMMAND_OK) {
36-
const auto err = error(PQerrorMessage(conn_.get()));
37-
PQclear(res);
38-
return err;
37+
const auto res = PostgresV2Result(PQgetResult(conn_.ptr()));
38+
if (PQresultStatus(res.ptr()) != PGRES_COMMAND_OK) {
39+
return error(PQerrorMessage(conn_.ptr()));
3940
}
40-
PQclear(res);
4141
return Nothing{};
4242
}
4343

@@ -53,14 +53,14 @@ Result<Nothing> Connection::insert_impl(
5353

5454
const auto sql = to_sql_impl(_stmt);
5555

56-
const auto res = PQprepare(conn_.get(), name.c_str(), sql.c_str(),
57-
_data.at(0).size(), nullptr);
56+
const auto res = PostgresV2Result(PQprepare(
57+
conn_.ptr(), name.c_str(), sql.c_str(), _data.at(0).size(), nullptr));
5858

59-
const auto status = PQresultStatus(res);
59+
const auto status = PQresultStatus(res.ptr());
6060

6161
if (status != PGRES_COMMAND_OK) {
6262
return error("Generating prepared statement for '" + sql +
63-
"' failed: " + PQresultErrorMessage(res));
63+
"' failed: " + PQresultErrorMessage(res.ptr()));
6464
}
6565

6666
std::vector<const char*> current_row(_data[0].size());
@@ -71,7 +71,6 @@ Result<Nothing> Connection::insert_impl(
7171
const auto& d = _data[i];
7272

7373
if (d.size() != current_row.size()) {
74-
execute("ROLLBACK;");
7574
execute("DEALLOCATE " + name + ";");
7675
return error("Error in entry " + std::to_string(i) + ": Expected " +
7776
std::to_string(current_row.size()) + " entries, got " +
@@ -82,60 +81,38 @@ Result<Nothing> Connection::insert_impl(
8281
current_row[j] = d[j] ? d[j]->c_str() : nullptr;
8382
}
8483

85-
const auto res = PQexecPrepared(conn_.get(), // conn
86-
name.c_str(), // stmtName
87-
n_params, // nParams
88-
current_row.data(), // paramValues
89-
nullptr, // paramLengths
90-
nullptr, // paramFormats
91-
0 // resultFormat
92-
);
84+
const auto res =
85+
PostgresV2Result(PQexecPrepared(conn_.ptr(), // conn
86+
name.c_str(), // stmtName
87+
n_params, // nParams
88+
current_row.data(), // paramValues
89+
nullptr, // paramLengths
90+
nullptr, // paramFormats
91+
0 // resultFormat
92+
));
93+
94+
const auto status = PQresultStatus(res.ptr());
9395

94-
const auto status = PQresultStatus(res);
9596
if (status != PGRES_COMMAND_OK) {
96-
PQclear(res);
9797
const auto err = error(std::string("Executing INSERT failed: ") +
98-
PQresultErrorMessage(res));
99-
execute("ROLLBACK;");
98+
PQresultErrorMessage(res.ptr()));
10099
execute("DEALLOCATE " + name + ";");
101100
return err;
102101
}
103-
PQclear(res);
104102
}
105103

106104
return execute("DEALLOCATE " + name + ";");
107105
}
108106

109107
rfl::Result<Ref<Connection>> Connection::make(
110108
const Credentials& _credentials) noexcept {
111-
try {
112-
return Ref<Connection>::make(_credentials);
113-
} catch (std::exception& e) {
114-
return error(e.what());
115-
}
116-
}
117-
118-
typename Connection::ConnPtr Connection::make_conn(
119-
const std::string& _conn_str) {
120-
const auto raw_ptr = PQconnectdb(_conn_str.c_str());
121-
122-
if (PQstatus(raw_ptr) != CONNECTION_OK) {
123-
const auto msg = std::string("Connection to postgres failed: ") +
124-
PQerrorMessage(raw_ptr);
125-
PQfinish(raw_ptr);
126-
throw std::runtime_error(msg.c_str());
127-
}
128-
129-
return ConnPtr::make(std::shared_ptr<PGconn>(raw_ptr, &PQfinish)).value();
109+
return PostgresV2Connection::make(_credentials.to_str())
110+
.transform([](auto&& _conn) { return Ref<Connection>::make(_conn); });
130111
}
131112

132113
Result<Ref<Iterator>> Connection::read_impl(const dynamic::SelectFrom& _query) {
133114
const auto sql = postgres::to_sql_impl(_query);
134-
try {
135-
return Ref<Iterator>::make(sql, conn_);
136-
} catch (std::exception& e) {
137-
return error(e.what());
138-
}
115+
return Ref<Iterator>::make(sql, conn_);
139116
}
140117

141118
Result<Nothing> Connection::rollback() noexcept { return execute("ROLLBACK;"); }
@@ -172,17 +149,15 @@ Result<Nothing> Connection::write_impl(
172149
const std::vector<std::vector<std::optional<std::string>>>& _data) {
173150
for (const auto& line : _data) {
174151
const auto buffer = to_buffer(line);
175-
const auto success = PQputCopyData(conn_.get(), buffer.c_str(),
152+
const auto success = PQputCopyData(conn_.ptr(), buffer.c_str(),
176153
static_cast<int>(buffer.size()));
177154
if (success != 1) {
178-
PQputCopyEnd(conn_.get(), NULL);
179-
while (auto res = PQgetResult(conn_.get()))
180-
PQclear(res);
181-
155+
PQputCopyEnd(conn_.ptr(), NULL);
182156
return error("Error occurred while writing data to postgres.");
183157
}
184158
}
185159
return Nothing{};
186160
}
187161

188162
} // namespace sqlgen::postgres
163+

src/sqlgen/postgres/Iterator.cpp

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010

1111
namespace sqlgen::postgres {
1212

13-
Iterator::Iterator(const std::string& _sql, const ConnPtr& _conn)
13+
Iterator::Iterator(const std::string& _sql, const Conn& _conn)
1414
: cursor_name_(make_cursor_name()), conn_(_conn), end_(false) {
1515
exec(conn_, "BEGIN").value();
1616
exec(conn_, "DECLARE " + cursor_name_ + " CURSOR FOR " + _sql).value();
@@ -33,22 +33,22 @@ Result<std::vector<std::vector<std::optional<std::string>>>> Iterator::next(
3333
return error("End is reached.");
3434
}
3535

36-
const auto to_vector = [](const Ref<PGresult>& _res)
36+
const auto to_vector = [](const PostgresV2Result& _res)
3737
-> std::vector<std::vector<std::optional<std::string>>> {
38-
const int num_rows = PQntuples(_res.get());
39-
const int num_cols = PQnfields(_res.get());
38+
const int num_rows = PQntuples(_res.ptr());
39+
const int num_cols = PQnfields(_res.ptr());
4040

4141
std::vector<std::vector<std::optional<std::string>>> vec(num_rows);
4242

4343
for (int i = 0; i < num_rows; ++i) {
4444
std::vector<std::optional<std::string>> row(num_cols);
4545

4646
for (int j = 0; j < num_cols; ++j) {
47-
const bool is_null = PQgetisnull(_res.get(), i, j);
47+
const bool is_null = PQgetisnull(_res.ptr(), i, j);
4848
if (is_null) {
4949
row[j] = std::nullopt;
5050
} else {
51-
row[j] = std::string(PQgetvalue(_res.get(), i, j));
51+
row[j] = std::string(PQgetvalue(_res.ptr(), i, j));
5252
}
5353
}
5454

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
#include "sqlgen/postgres/PostgresV2Connection.hpp"
2+
3+
namespace sqlgen::postgres {
4+
5+
rfl::Result<PostgresV2Connection> PostgresV2Connection::make(
6+
const std::string& _conn_str) noexcept {
7+
auto conn = PQconnectdb(_conn_str.c_str());
8+
if (PQstatus(conn) != CONNECTION_OK) {
9+
const auto msg =
10+
std::string("Connection to postgres failed: ") + PQerrorMessage(conn);
11+
PQfinish(conn);
12+
return error(msg);
13+
}
14+
return PostgresV2Connection(conn);
15+
}
16+
17+
} // namespace sqlgen::postgres

0 commit comments

Comments
 (0)