Skip to content

Commit ce7932a

Browse files
author
peng.li24
committed
feat: add auto-double py::array overloads — accept any numpy dtype, always produce Geometry<double>
1 parent 2b437e9 commit ce7932a

2 files changed

Lines changed: 69 additions & 0 deletions

File tree

pycpp/geometry_py.h

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,25 @@ py::array_t<T> native_to_array(const T* data, size_t rows, size_t cols) {
4343
return result;
4444
}
4545

46+
/// Convert any py::array coordinate buffer to std::vector<double>.
47+
/// dtype double → direct copy; float32/other → cast & widen to double.
48+
inline std::vector<double> array_to_double_vec(const py::array& arr) {
49+
auto buf = arr.request();
50+
size_t sz = static_cast<size_t>(buf.shape[0]) * 2;
51+
std::vector<double> tmp(sz);
52+
if (arr.dtype().is(py::dtype::of<double>())) {
53+
const double* src = static_cast<const double*>(buf.ptr);
54+
std::copy(src, src + sz, tmp.begin());
55+
} else {
56+
// float32 or other → cast through pybind11 to float32, then widen
57+
auto f32 = py::cast<py::array_t<float>>(arr);
58+
auto fbuf = f32.request();
59+
const float* src = static_cast<const float*>(fbuf.ptr);
60+
for (size_t i = 0; i < sz; ++i) tmp[i] = static_cast<double>(src[i]);
61+
}
62+
return tmp;
63+
}
64+
4665
// ============================================================================
4766
// Factory functions — all overloaded, no _f32 suffixes
4867
// ============================================================================
@@ -61,6 +80,11 @@ inline Point<float> point(const py::array_t<float>& arr) {
6180
const float* p = static_cast<const float*>(buf.ptr);
6281
return Point<float>(buf.size > 0 ? p[0] : 0, buf.size > 1 ? p[1] : 0);
6382
}
83+
// auto-double: accept any dtype (int, unknown), always produce Point<double>
84+
inline Point<double> point(const py::array& arr) {
85+
auto tmp = array_to_double_vec(arr);
86+
return Point<double>(tmp[0], tmp[1]);
87+
}
6488

6589
// -- LineString --
6690
inline LineString<double> linestring(const py::array_t<double>& arr) {
@@ -71,6 +95,11 @@ inline LineString<float> linestring(const py::array_t<float>& arr) {
7195
auto buf = arr.request();
7296
return LineString<float>(static_cast<const float*>(buf.ptr), buf.shape[0], buf.shape[1]);
7397
}
98+
inline LineString<double> linestring(const py::array& arr) {
99+
py::ssize_t n = arr.request().shape[0];
100+
auto tmp = array_to_double_vec(arr);
101+
return LineString<double>(tmp.data(), static_cast<size_t>(n), 2);
102+
}
74103

75104
// -- Polygon --
76105
inline Polygon<double> polygon(const py::array_t<double>& arr) {
@@ -81,12 +110,22 @@ inline Polygon<float> polygon(const py::array_t<float>& arr) {
81110
auto buf = arr.request();
82111
return Polygon<float>(static_cast<const float*>(buf.ptr), buf.shape[0], buf.shape[1]);
83112
}
113+
inline Polygon<double> polygon(const py::array& arr) {
114+
py::ssize_t n = arr.request().shape[0];
115+
auto tmp = array_to_double_vec(arr);
116+
return Polygon<double>(tmp.data(), static_cast<size_t>(n), 2);
117+
}
84118

85119
// -- LinearRing (double only for now) --
86120
inline LinearRing<double> linearring(const py::array_t<double>& arr) {
87121
auto buf = arr.request();
88122
return LinearRing<double>(static_cast<const double*>(buf.ptr), buf.shape[0], buf.shape[1]);
89123
}
124+
inline LinearRing<double> linearring(const py::array& arr) {
125+
py::ssize_t n = arr.request().shape[0];
126+
auto tmp = array_to_double_vec(arr);
127+
return LinearRing<double>(tmp.data(), static_cast<size_t>(n), 2);
128+
}
90129

91130
// -- MultiPoint: single array of shape (n_pts, 2) --
92131
inline MultiPoint<double> multipoint(const py::array_t<double>& arr) {
@@ -97,6 +136,11 @@ inline MultiPoint<float> multipoint(const py::array_t<float>& arr) {
97136
auto buf = arr.request();
98137
return MultiPoint<float>(static_cast<const float*>(buf.ptr), buf.shape[0], buf.shape[1]);
99138
}
139+
inline MultiPoint<double> multipoint(const py::array& arr) {
140+
py::ssize_t n = arr.request().shape[0];
141+
auto tmp = array_to_double_vec(arr);
142+
return MultiPoint<double>(tmp.data(), static_cast<size_t>(n), 2);
143+
}
100144

101145
// -- MultiLineString: vector of arrays, each (n_rows, 2); dtype distinguishes overload --
102146
inline MultiLineString<double> multilinestring(const std::vector<py::array_t<double>>& arrays) {
@@ -115,6 +159,15 @@ inline MultiLineString<float> multilinestring(const std::vector<py::array_t<floa
115159
}
116160
return mls;
117161
}
162+
inline MultiLineString<double> multilinestring(const std::vector<py::array>& arrays) {
163+
MultiLineString<double> mls;
164+
for (auto& arr : arrays) {
165+
py::ssize_t n = arr.request().shape[0];
166+
auto tmp = array_to_double_vec(arr);
167+
mls.add_line(tmp.data(), static_cast<size_t>(n), 2);
168+
}
169+
return mls;
170+
}
118171

119172
// -- MultiPolygon: vector of arrays, each (n_rows, 2); dtype distinguishes overload --
120173
inline MultiPolygon<double> multipolygon(const std::vector<py::array_t<double>>& arrays) {
@@ -133,6 +186,15 @@ inline MultiPolygon<float> multipolygon(const std::vector<py::array_t<float>>& a
133186
}
134187
return mp;
135188
}
189+
inline MultiPolygon<double> multipolygon(const std::vector<py::array>& arrays) {
190+
MultiPolygon<double> mp;
191+
for (auto& arr : arrays) {
192+
py::ssize_t n = arr.request().shape[0];
193+
auto tmp = array_to_double_vec(arr);
194+
mp.add_polygon(tmp.data(), static_cast<size_t>(n), 2);
195+
}
196+
return mp;
197+
}
136198

137199
// ============================================================================
138200
// Cross-type distance

tests/module.cpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,13 +34,17 @@ PYBIND11_MODULE(shapelycpp, m) {
3434
m.def("point", py::overload_cast<double, double>(&point), py::arg("x"), py::arg("y"));
3535
m.def("point", py::overload_cast<const py::array_t<double>&>(&point), py::arg("coords"));
3636
m.def("point", py::overload_cast<const py::array_t<float>&>(&point), py::arg("coords"));
37+
m.def("point", py::overload_cast<const py::array&>(&point), py::arg("coords")); // auto-double
3738

3839
// linestring / polygon / linearring: array dtype selects f64 or f32
3940
m.def("linestring", py::overload_cast<const py::array_t<double>&>(&linestring), py::arg("coords"));
4041
m.def("linestring", py::overload_cast<const py::array_t<float>&>(&linestring), py::arg("coords"));
42+
m.def("linestring", py::overload_cast<const py::array&>(&linestring), py::arg("coords")); // auto-double
4143
m.def("polygon", py::overload_cast<const py::array_t<double>&>(&polygon), py::arg("coords"));
4244
m.def("polygon", py::overload_cast<const py::array_t<float>&>(&polygon), py::arg("coords"));
45+
m.def("polygon", py::overload_cast<const py::array&>(&polygon), py::arg("coords")); // auto-double
4346
m.def("linearring", py::overload_cast<const py::array_t<double>&>(&linearring), py::arg("coords"));
47+
m.def("linearring", py::overload_cast<const py::array&>(&linearring), py::arg("coords")); // auto-double
4448

4549
// ======================================================================
4650
// Point<double> — full API
@@ -177,10 +181,13 @@ PYBIND11_MODULE(shapelycpp, m) {
177181
// -- Multi* factories (multipoint overloads by array dtype; multilinestring/multipolygon by vector<array_t<T>>) --
178182
m.def("multipoint", py::overload_cast<const py::array_t<double>&>(&multipoint), py::arg("coords"));
179183
m.def("multipoint", py::overload_cast<const py::array_t<float>&>(&multipoint), py::arg("coords"));
184+
m.def("multipoint", py::overload_cast<const py::array&>(&multipoint), py::arg("coords")); // auto-double
180185
m.def("multilinestring", py::overload_cast<const std::vector<py::array_t<double>>&>(&multilinestring), py::arg("lines"));
181186
m.def("multilinestring", py::overload_cast<const std::vector<py::array_t<float>>&>(&multilinestring), py::arg("lines"));
187+
m.def("multilinestring", py::overload_cast<const std::vector<py::array>&>(&multilinestring), py::arg("lines")); // auto-double
182188
m.def("multipolygon", py::overload_cast<const std::vector<py::array_t<double>>&>(&multipolygon), py::arg("polygons"));
183189
m.def("multipolygon", py::overload_cast<const std::vector<py::array_t<float>>&>(&multipolygon), py::arg("polygons"));
190+
m.def("multipolygon", py::overload_cast<const std::vector<py::array>&>(&multipolygon), py::arg("polygons")); // auto-double
184191

185192
// ======================================================================
186193
// MultiPoint<double> — full API

0 commit comments

Comments
 (0)