Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 21 additions & 0 deletions atompack-py/src/database_flat.rs
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,14 @@ pub(super) fn get_molecules_flat_soa_impl<'py>(
i, sec.key
)));
}
if sec.type_tag != schema_entry.type_tag {
// Same key + kind but different type tag would silently
// reinterpret the bytes downstream as a different dtype.
return Err(invalid_data(format!(
"Section '{}' type tag mismatch at molecule {}: schema is {}, got {}",
sec.key, i, schema_entry.type_tag, sec.type_tag
)));
}
if schema_entry.per_atom {
let expected = n.checked_mul(schema_entry.elem_bytes).ok_or_else(|| {
invalid_data(format!("Section '{}' payload length overflow", sec.key))
Expand All @@ -199,6 +207,19 @@ pub(super) fn get_molecules_flat_soa_impl<'py>(
expected
)));
}
} else if schema_entry.slot_bytes != 0
&& sec.payload.len() != schema_entry.slot_bytes
{
// Per-molecule slot: schema is set from molecule 0; if a later
// molecule's payload disagrees, the memcpy below would OOB-write
// into an adjacent slot from a parallel rayon thread.
return Err(invalid_data(format!(
"Per-molecule section '{}' has invalid payload length {} for molecule {} (schema slot is {})",
sec.key,
sec.payload.len(),
i,
schema_entry.slot_bytes
)));
}

if schema_entry.slot_bytes == 0 {
Expand Down
55 changes: 55 additions & 0 deletions atompack-py/tests/test_database.py
Original file line number Diff line number Diff line change
Expand Up @@ -577,3 +577,58 @@ def test_get_molecules_flat_empty(tmp_path: Path) -> None:
assert batch["n_atoms"].shape == (0,)
assert batch["positions"].shape == (0, 3)
assert batch["atomic_numbers"].shape == (0,)


def test_overlong_property_key_rejected_at_write(tmp_path: Path) -> None:
    """A property key longer than 255 bytes must be refused at write time.

    The on-disk section header stores the key length in a single u8, so a
    256-byte key cannot be represented. Adding such a molecule is expected to
    raise ValueError rather than write a record that cannot be read back.
    """
    oversized_key = "x" * 256  # one byte past what a u8 length can describe
    molecule = _make_molecule(-1.0)
    molecule.set_property(oversized_key, 1.0)

    database = atompack.Database(str(tmp_path / "overlong.atp"))
    with pytest.raises(ValueError, match=r"max is 255|too long|255 bytes"):
        database.add_molecule(molecule)


def test_get_molecules_flat_rejects_per_mol_slot_mismatch(tmp_path: Path) -> None:
    """Flat reads must fail when a per-molecule property changes length.

    Both molecules carry the key "vec", but with float64 arrays of different
    lengths. get_molecules_flat sizes the per-molecule slot from the first
    molecule's payload, so the second record has to be rejected with a clear
    error instead of being copied into a slot that is too small for it.
    """
    short_vec = np.array([1.0, 2.0], dtype=np.float64)  # 2 x float64 = 16 bytes
    long_vec = np.array([1.0, 2.0, 3.0], dtype=np.float64)  # 3 x float64 = 24 bytes

    first = _make_molecule(-1.0)
    first.set_property("vec", short_vec)
    second = _make_molecule(-2.0)
    second.set_property("vec", long_vec)

    db_file = tmp_path / "uneven.atp"
    writer = atompack.Database(str(db_file))
    for molecule in (first, second):
        writer.add_molecule(molecule)
    writer.flush()

    reader = atompack.Database.open(str(db_file))
    with pytest.raises(ValueError, match=r"invalid payload length|schema slot"):
        reader.get_molecules_flat([0, 1])


def test_get_molecules_flat_rejects_per_mol_type_tag_mismatch(tmp_path: Path) -> None:
    """Flat reads must fail when one key is stored with two different dtypes.

    The float and the int payloads are both 8 bytes wide, so a length check
    alone cannot tell them apart; only the type tag differs. The read is
    expected to raise instead of reinterpreting the second molecule's bytes
    using the dtype recorded for the first.
    """
    as_float = _make_molecule(-1.0)
    as_float.set_property("scalar", 1.5)  # stored as TYPE_FLOAT (f64)
    as_int = _make_molecule(-2.0)
    as_int.set_property("scalar", 7)  # stored as TYPE_INT (i64), also 8 bytes

    db_file = tmp_path / "tagmix.atp"
    writer = atompack.Database(str(db_file))
    for molecule in (as_float, as_int):
        writer.add_molecule(molecule)
    writer.flush()

    reader = atompack.Database.open(str(db_file))
    with pytest.raises(ValueError, match=r"type tag mismatch"):
        reader.get_molecules_flat([0, 1])
21 changes: 17 additions & 4 deletions atompack/src/storage/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,19 @@ impl SharedMmapBytes {
}
}

/// Convert a record byte length to the on-disk `u32` index field, erroring
/// instead of silently truncating when the record exceeds 4 GiB.
fn record_size_u32(len: usize, what: &str) -> Result<u32> {
len.try_into().map_err(|_| {
Error::InvalidData(format!(
"{} size {} bytes exceeds u32::MAX ({}); index field cannot represent it",
what,
len,
u32::MAX
))
})
}

pub struct AtomDatabase {
path: PathBuf,
compression: CompressionType,
Expand Down Expand Up @@ -344,7 +357,7 @@ impl AtomDatabase {
let compressed_records: Vec<(Vec<u8>, u32, u32)> = records
.par_iter()
.map(|(bytes, num_atoms)| {
let uncompressed_size = bytes.len() as u32;
let uncompressed_size = record_size_u32(bytes.len(), "uncompressed record")?;
let compressed = compress(bytes, compression)?;
Ok((compressed, uncompressed_size, *num_atoms))
})
Expand All @@ -361,7 +374,7 @@ impl AtomDatabase {

new_indices.push(MoleculeIndex {
offset,
compressed_size: compressed_data.len() as u32,
compressed_size: record_size_u32(compressed_data.len(), "compressed record")?,
uncompressed_size,
num_atoms,
});
Expand Down Expand Up @@ -390,7 +403,7 @@ impl AtomDatabase {
let compressed_records: Vec<(Vec<u8>, u32, u32)> = records
.into_par_iter()
.map(|(bytes, num_atoms)| {
let uncompressed_size = bytes.len() as u32;
let uncompressed_size = record_size_u32(bytes.len(), "uncompressed record")?;
let compressed = compress(&bytes, compression)?;
Ok((compressed, uncompressed_size, num_atoms))
})
Expand All @@ -406,7 +419,7 @@ impl AtomDatabase {

new_indices.push(MoleculeIndex {
offset,
compressed_size: compressed_data.len() as u32,
compressed_size: record_size_u32(compressed_data.len(), "compressed record")?,
uncompressed_size,
num_atoms,
});
Expand Down
65 changes: 48 additions & 17 deletions atompack/src/storage/soa.rs
Original file line number Diff line number Diff line change
@@ -1,13 +1,37 @@
use super::*;

/// Write a single tagged section: [kind:u8][key_len:u8][key][type_tag:u8][payload_len:u32][payload]
fn write_section(buf: &mut Vec<u8>, kind: u8, key: &str, type_tag: u8, payload: &[u8]) {
///
/// Errors instead of silently truncating when the key exceeds the 255-byte
/// `key_len: u8` field or the payload exceeds the `u32` length field.
fn write_section(
buf: &mut Vec<u8>,
kind: u8,
key: &str,
type_tag: u8,
payload: &[u8],
) -> Result<()> {
let key_len: u8 = key.len().try_into().map_err(|_| {
Error::InvalidData(format!(
"Section key '{}...' is {} bytes; max is 255",
&key[..32.min(key.len())],
key.len()
))
})?;
let payload_len: u32 = payload.len().try_into().map_err(|_| {
Error::InvalidData(format!(
"Section '{}' payload is {} bytes; max is u32::MAX",
key,
payload.len()
))
})?;
buf.push(kind);
buf.push(key.len() as u8);
buf.push(key_len);
buf.extend_from_slice(key.as_bytes());
buf.push(type_tag);
buf.extend_from_slice(&(payload.len() as u32).to_le_bytes());
buf.extend_from_slice(&payload_len.to_le_bytes());
buf.extend_from_slice(payload);
Ok(())
}

fn property_value_type_tag(value: &PropertyValue) -> u8 {
Expand Down Expand Up @@ -236,7 +260,7 @@ pub(super) fn serialize_molecule_soa(molecule: &Molecule) -> Result<Vec<u8>> {
}
buf.extend_from_slice(&molecule.atomic_numbers);

let mut n_sections: u16 = 0;
let mut n_sections: usize = 0;
if molecule.charges.is_some() {
n_sections += 1;
}
Expand All @@ -261,21 +285,28 @@ pub(super) fn serialize_molecule_soa(molecule: &Molecule) -> Result<Vec<u8>> {
if molecule.velocities.is_some() {
n_sections += 1;
}
n_sections += molecule.atom_properties.len() as u16;
n_sections += molecule.properties.len() as u16;
buf.extend_from_slice(&n_sections.to_le_bytes());
n_sections += molecule.atom_properties.len();
n_sections += molecule.properties.len();
let n_sections_u16: u16 = n_sections.try_into().map_err(|_| {
Error::InvalidData(format!(
"Molecule has {} sections; on-disk format limit is {}",
n_sections,
u16::MAX
))
})?;
buf.extend_from_slice(&n_sections_u16.to_le_bytes());

if let Some(ref charges) = molecule.charges {
let mut payload = Vec::with_capacity(charges.len() * 8);
extend_f64(&mut payload, charges);
write_section(&mut buf, KIND_BUILTIN, "charges", TYPE_F64_ARRAY, &payload);
write_section(&mut buf, KIND_BUILTIN, "charges", TYPE_F64_ARRAY, &payload)?;
}
if let Some(ref cell) = molecule.cell {
let mut payload = Vec::with_capacity(72);
for row in cell {
extend_f64(&mut payload, row);
}
write_section(&mut buf, KIND_BUILTIN, "cell", TYPE_MAT3X3_F64, &payload);
write_section(&mut buf, KIND_BUILTIN, "cell", TYPE_MAT3X3_F64, &payload)?;
}
if let Some(energy) = molecule.energy {
write_section(
Expand All @@ -284,28 +315,28 @@ pub(super) fn serialize_molecule_soa(molecule: &Molecule) -> Result<Vec<u8>> {
"energy",
TYPE_FLOAT,
&energy.to_le_bytes(),
);
)?;
}
if let Some(ref forces) = molecule.forces {
let mut payload = Vec::with_capacity(forces.len() * 12);
for f in forces {
extend_f32(&mut payload, f);
}
write_section(&mut buf, KIND_BUILTIN, "forces", TYPE_VEC3_F32, &payload);
write_section(&mut buf, KIND_BUILTIN, "forces", TYPE_VEC3_F32, &payload)?;
}
if let Some(ref name) = molecule.name {
write_section(&mut buf, KIND_BUILTIN, "name", TYPE_STRING, name.as_bytes());
write_section(&mut buf, KIND_BUILTIN, "name", TYPE_STRING, name.as_bytes())?;
}
if let Some(ref pbc) = molecule.pbc {
let payload = [pbc[0] as u8, pbc[1] as u8, pbc[2] as u8];
write_section(&mut buf, KIND_BUILTIN, "pbc", TYPE_BOOL3, &payload);
write_section(&mut buf, KIND_BUILTIN, "pbc", TYPE_BOOL3, &payload)?;
}
if let Some(ref stress) = molecule.stress {
let mut payload = Vec::with_capacity(72);
for row in stress {
extend_f64(&mut payload, row);
}
write_section(&mut buf, KIND_BUILTIN, "stress", TYPE_MAT3X3_F64, &payload);
write_section(&mut buf, KIND_BUILTIN, "stress", TYPE_MAT3X3_F64, &payload)?;
}
if let Some(ref velocities) = molecule.velocities {
let mut payload = Vec::with_capacity(velocities.len() * 12);
Expand All @@ -318,7 +349,7 @@ pub(super) fn serialize_molecule_soa(molecule: &Molecule) -> Result<Vec<u8>> {
"velocities",
TYPE_VEC3_F32,
&payload,
);
)?;
}

let mut atom_keys: Vec<&String> = molecule.atom_properties.keys().collect();
Expand All @@ -332,7 +363,7 @@ pub(super) fn serialize_molecule_soa(molecule: &Molecule) -> Result<Vec<u8>> {
key,
property_value_type_tag(value),
&payload,
);
)?;
}

let mut prop_keys: Vec<&String> = molecule.properties.keys().collect();
Expand All @@ -346,7 +377,7 @@ pub(super) fn serialize_molecule_soa(molecule: &Molecule) -> Result<Vec<u8>> {
key,
property_value_type_tag(value),
&payload,
);
)?;
}

Ok(buf)
Expand Down
Loading