Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
-- Add a dedicated, nullable column for the NMX-C partition ID.
-- Existing rows receive NULL for this column; this migration does not backfill it.
ALTER TABLE nvlink_partitions
ADD COLUMN nmx_c_partition_id INTEGER;

-- Drop the constraint on nmx_m_id (presumably the default-named UNIQUE
-- constraint; confirm against the original table definition).
-- IF EXISTS keeps this step from failing when the constraint is absent.
ALTER TABLE nvlink_partitions
DROP CONSTRAINT IF EXISTS nvlink_partitions_nmx_m_id_key;

-- Allow rows that have no NMX-M identifier at all.
ALTER TABLE nvlink_partitions
ALTER COLUMN nmx_m_id DROP NOT NULL;

-- Every row must still carry at least one external identifier:
-- either nmx_m_id or nmx_c_partition_id has to be non-NULL.
ALTER TABLE nvlink_partitions
ADD CONSTRAINT nvlink_partitions_external_id_check
CHECK (
nmx_m_id IS NOT NULL
OR nmx_c_partition_id IS NOT NULL
);
26 changes: 25 additions & 1 deletion crates/api-db/src/nvl_partition.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
*/

use carbide_uuid::nvlink::NvLinkPartitionId;
use carbide_uuid::nvlink::{NvLinkDomainId, NvLinkLogicalPartitionId};
use model::nvl_partition::{NewNvlPartition, NvlPartition, NvlPartitionSnapshotPgJson};
use sqlx::PgConnection;

Expand All @@ -42,15 +43,17 @@ pub async fn create(
let query = "INSERT INTO nvlink_partitions (
id,
nmx_m_id,
nmx_c_partition_id,
name,
domain_uuid,
logical_partition_id)
VALUES ($1, $2, $3, $4, $5)
VALUES ($1, $2, $3, $4, $5, $6)
RETURNING row_to_json(nvlink_partitions.*)";

let partition: NvlPartitionSnapshotPgJson = sqlx::query_as(query)
.bind(value.id)
.bind(&value.nmx_m_id)
.bind(value.nmx_c_partition_id)
.bind(value.name.as_str())
.bind(value.domain_uuid)
.bind(value.logical_partition_id)
Expand Down Expand Up @@ -169,3 +172,24 @@ pub async fn final_delete(

Ok(partition)
}

/// Permanently deletes the partition rows of one logical partition / domain pair that
/// carry only an NMX-M identifier.
///
/// A row matches when its `nmx_m_id` is set and its `nmx_c_partition_id` is `NULL`;
/// rows that have an NMX-C partition ID are left untouched.
///
/// Returns the IDs of the deleted rows, which is empty when nothing matched.
///
/// # Errors
///
/// Returns a [`DatabaseError`] built from the statement text and the underlying
/// driver error if the delete fails.
pub async fn final_delete_nmx_m_only_for_logical_partition_and_domain(
    logical_partition_id: NvLinkLogicalPartitionId,
    domain_uuid: NvLinkDomainId,
    txn: &mut PgConnection,
) -> Result<Vec<NvLinkPartitionId>, DatabaseError> {
    let delete_sql = "DELETE FROM nvlink_partitions
WHERE logical_partition_id = $1
AND domain_uuid = $2
AND nmx_m_id IS NOT NULL
AND nmx_c_partition_id IS NULL
RETURNING id";

    // The row type is inferred from the function's return type.
    sqlx::query_as(delete_sql)
        .bind(logical_partition_id)
        .bind(domain_uuid)
        .fetch_all(txn)
        .await
        .map_err(|e| DatabaseError::new(delete_sql, e))
}
11 changes: 8 additions & 3 deletions crates/api-model/src/nvl_partition.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,8 @@ pub struct NewNvlPartition {
pub name: NvlPartitionName,
pub logical_partition_id: NvLinkLogicalPartitionId,
pub domain_uuid: NvLinkDomainId,
pub nmx_m_id: String,
pub nmx_m_id: Option<String>,
pub nmx_c_partition_id: Option<i32>,
}

#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
Expand Down Expand Up @@ -68,7 +69,8 @@ impl From<NvlPartitionName> for String {
#[derive(Debug, Clone)]
pub struct NvlPartition {
pub id: NvLinkPartitionId,
pub nmx_m_id: String,
pub nmx_m_id: Option<String>,
pub nmx_c_partition_id: Option<i32>,
pub domain_uuid: NvLinkDomainId,
pub name: NvlPartitionName,
pub created: DateTime<Utc>,
Expand All @@ -85,7 +87,9 @@ pub fn is_marked_as_deleted(partition: &NvlPartition) -> bool {
#[derive(Debug, Serialize, Deserialize)]
pub struct NvlPartitionSnapshotPgJson {
pub id: NvLinkPartitionId,
pub nmx_m_id: String,
pub nmx_m_id: Option<String>,
#[serde(default)]
pub nmx_c_partition_id: Option<i32>,
pub name: NvlPartitionName,
pub domain_uuid: NvLinkDomainId,
pub created: DateTime<Utc>,
Expand All @@ -100,6 +104,7 @@ impl TryFrom<NvlPartitionSnapshotPgJson> for NvlPartition {
Ok(Self {
id: value.id,
nmx_m_id: value.nmx_m_id,
nmx_c_partition_id: value.nmx_c_partition_id,
domain_uuid: value.domain_uuid,
name: value.name,
created: value.created,
Expand Down
110 changes: 110 additions & 0 deletions crates/api/src/tests/nvl_instance.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1936,6 +1936,116 @@ async fn test_create_instance_gpu_in_unknown_partition(pool: sqlx::PgPool) {
assert_eq!(gpu_uid_count, 4);
}

/// Verifies that one partition-monitor iteration removes a pre-existing partition row
/// that only has an `nmx_m_id`, and leaves exactly one row for the same logical
/// partition and domain that is identified by `nmx_c_partition_id` instead.
#[crate::sqlx_test]
async fn test_nmx_c_monitor_deletes_legacy_nmx_m_partition_row(pool: sqlx::PgPool) {
// Enable NVLink handling when the fixture config carries an NVLink section.
let mut config = common::api_fixtures::get_config();
if let Some(nvlink_config) = config.nvlink_config.as_mut() {
nvlink_config.enabled = true;
}

// NOTE(review): presumably `nmxc_unknown_partition` makes the NMX-C test double
// report an unknown partition; confirm against `TestEnvOverrides`.
let mut test_overrides = TestEnvOverrides::with_config(config);
test_overrides.nmxc_unknown_partition = Some(true);
let env =
common::api_fixtures::create_test_env_with_overrides(pool.clone(), test_overrides).await;
let segment_id = env.create_vpc_and_tenant_segment().await;

let NvlLogicalPartitionFixture {
id: logical_partition_id,
logical_partition: _logical_partition,
} = create_nvl_logical_partition(&env, "test_partition".to_string()).await;

// Bring up one managed host from the GB200 compute tray template and read the
// NVLink domain it reports.
let mh1 = create_managed_host_with_hardware_info_template(
&env,
HardwareInfoTemplate::Custom(
crate::tests::common::api_fixtures::host::GB200_COMPUTE_TRAY_1_INFO_JSON,
),
)
.await;
let machine1 = mh1.host().rpc_machine().await;
assert_eq!(&machine1.state, "Ready");
let domain_uuid = machine1
.nvlink_info
.as_ref()
.and_then(|info| info.domain_uuid)
.unwrap();

// Seed a "legacy" row directly through SQL: it has an nmx_m_id but no
// nmx_c_partition_id, and belongs to the same logical partition and domain.
let legacy_partition_id = carbide_uuid::nvlink::NvLinkPartitionId::new();
let legacy_nmx_m_id = "699c4bdb83acac93e9c1476f";
sqlx::query(
r#"
INSERT INTO nvlink_partitions
(id, nmx_m_id, name, domain_uuid, logical_partition_id)
VALUES
($1, $2, $3, $4, $5)
"#,
)
.bind(legacy_partition_id)
.bind(legacy_nmx_m_id)
.bind("legacy-nmxm-partition")
.bind(domain_uuid)
.bind(logical_partition_id)
.execute(&pool)
.await
.unwrap();

let discovery_info1 = machine1.discovery_info.as_ref().unwrap();
assert_eq!(discovery_info1.gpus.len(), 4);

// Assign every GPU that has platform info to the logical partition; the device
// instance is derived as `module_id - 1`.
let nvl_config = rpc::forge::InstanceNvLinkConfig {
gpu_configs: discovery_info1
.gpus
.iter()
.filter_map(|gpu| {
gpu.platform_info.as_ref().map(|platform_info| {
rpc::forge::InstanceNvLinkGpuConfig {
device_instance: platform_info.module_id - 1,
logical_partition_id: Some(logical_partition_id),
}
})
})
.collect(),
};

// Create the instance, then run a single monitor iteration.
let (_tinstance, _instance) =
create_instance_with_nvlink_config(&env, &mh1, nvl_config, segment_id).await;
env.run_nvl_partition_monitor_iteration().await;

// The seeded legacy row must no longer exist.
let legacy_row_count: i64 = sqlx::query_scalar(
r#"
SELECT COUNT(*)
FROM nvlink_partitions
WHERE id = $1
AND nmx_m_id = $2
AND nmx_c_partition_id IS NULL
"#,
)
.bind(legacy_partition_id)
.bind(legacy_nmx_m_id)
.fetch_one(&pool)
.await
.unwrap();
assert_eq!(legacy_row_count, 0);

// Exactly one row must remain for this logical partition and domain, and it must
// be identified by nmx_c_partition_id rather than nmx_m_id.
let replacement_rows: Vec<(Option<String>, Option<i32>)> = sqlx::query_as(
r#"
SELECT nmx_m_id, nmx_c_partition_id
FROM nvlink_partitions
WHERE logical_partition_id = $1
AND domain_uuid = $2
"#,
)
.bind(logical_partition_id)
.bind(domain_uuid)
.fetch_all(&pool)
.await
.unwrap();

assert_eq!(replacement_rows.len(), 1);
let (nmx_m_id, nmx_c_partition_id) = &replacement_rows[0];
assert!(nmx_m_id.is_none());
assert!(nmx_c_partition_id.is_some());
}

// `*_use_nmxc_simulator` integration tests only run when environment variable RUN_NMXC_SIMULATOR_TESTS is set (any value).
// Before running these tests, need to have nmx_simulator running on port 9601.
// Ex: "sudo ./install_simulators.sh -p 9601 -n 1 -g nmx-c-nvlink_2.0.0_2025-04-23_01-10_internal.tar.gz -i 127.0.0.0 -m enabled -t gb200_nvl36r1_c2g4_topology -d true"
Expand Down
111 changes: 93 additions & 18 deletions crates/nvlink-manager/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,11 @@ impl PartitionProcessingContext {
.collect();
let db_nvl_partitions = db_nvl_partitions
.into_iter()
.filter_map(|p| p.nmx_m_id.parse::<u32>().ok().map(|id| (id, p)))
.filter_map(|p| {
p.nmx_c_partition_id
.and_then(|id| u32::try_from(id).ok())
.map(|id| (id, p))
})
.collect();
Self {
nmx_c_partitions,
Expand Down Expand Up @@ -546,7 +550,15 @@ impl PartitionProcessingContext {
};

// Get the GPU IDs that are already in the partition, plus the GPU being added.
let nmx_c_partition_id = partition.nmx_m_id.parse::<u32>().unwrap();
let Some(nmx_c_partition_id) = partition
.nmx_c_partition_id
.and_then(|id| u32::try_from(id).ok())
else {
return Err(NvLinkManagerError::internal(format!(
"NMX-C partition ID is required for DB partition {}",
partition.id
)));
};
let gpu_uids: Vec<u64> = if let Some(nmx_c_partition) =
self.nmx_c_partitions
.get(&libnmxc::nmxc_model::PartitionId {
Expand Down Expand Up @@ -1656,29 +1668,56 @@ impl NvlPartitionMonitor {
for operation in operations {
match operation.operation_type {
NmxcPartitionOperationType::Create => {
let matching_partition = match nmx_c_partitions.values().find(|p| {
let p_uids: HashSet<u64> = p.gpu_uid_list.iter().copied().collect();
let op_uids: HashSet<u64> =
operation.gpu_uids.iter().copied().collect();
p_uids == op_uids
}) {
Some(p) => p,
None => {
tracing::error!(
"NMX-C partition not found for name {}",
operation.name
);
continue;
}
};
let Some(nmx_c_partition_id) = matching_partition
.partition_id
.as_ref()
.map(|id| id.partition_id)
else {
tracing::error!(
"NMX-C partition ID not found for name {}",
operation.name
);
continue;
};
let Ok(nmx_c_partition_id) = i32::try_from(nmx_c_partition_id) else {
tracing::error!(
"NMX-C partition ID does not fit in database column for name {}",
operation.name
);
continue;
};

// Create the nvl partition in the database
let new_partition = model::nvl_partition::NewNvlPartition {
id: NvLinkPartitionId::new(),
logical_partition_id,
name: NvlPartitionName::try_from(operation.name.clone())?,
domain_uuid: operation.domain_uuid.unwrap_or_default(),
nmx_m_id: match nmx_c_partitions.values().find(|p| {
let p_uids: HashSet<u64> = p.gpu_uid_list.iter().copied().collect();
let op_uids: HashSet<u64> =
operation.gpu_uids.iter().copied().collect();
p_uids == op_uids
}) {
Some(p) => nmx_c_partition_id_string(p),
None => {
tracing::error!(
"NMX-C partition not found for name {}",
operation.name
);
continue;
}
},
nmx_m_id: None,
nmx_c_partition_id: Some(nmx_c_partition_id),
};
let _partition = db::nvl_partition::create(&new_partition, txn).await?;
self.delete_legacy_nmx_m_partition_rows(
txn,
logical_partition_id,
operation.domain_uuid,
)
.await?;
}
NmxcPartitionOperationType::Remove(_) => {
db::nvl_partition::final_delete(
Expand All @@ -1688,8 +1727,14 @@ impl NvlPartitionMonitor {
.await?;
}
NmxcPartitionOperationType::Update(_) => {
// No-op, since partition membership is not tracked in the partitions table. The status observation of the
// Partition membership is not tracked in the partitions table. The status observation of the
// added/removed GPUs will be updated.
self.delete_legacy_nmx_m_partition_rows(
txn,
logical_partition_id,
operation.domain_uuid,
)
.await?;
}
NmxcPartitionOperationType::RemoveUnknownPartition(_) => {
// No-op, since default partition membership is not tracked in the partitions table. The status observation of the
Expand All @@ -1710,6 +1755,36 @@ impl NvlPartitionMonitor {
Ok(())
}

/// Deletes the partition rows of the given logical partition and domain that only
/// carry an NMX-M identifier, and logs how many rows were removed.
///
/// Does nothing and returns `Ok(())` when `domain_uuid` is `None`. Nothing is logged
/// when the delete matched no rows.
///
/// # Errors
///
/// Propagates any error returned by the underlying database delete.
async fn delete_legacy_nmx_m_partition_rows(
    &self,
    txn: &mut sqlx::Transaction<'_, sqlx::Postgres>,
    logical_partition_id: NvLinkLogicalPartitionId,
    domain_uuid: Option<NvLinkDomainId>,
) -> NvLinkManagerResult<()> {
    // The delete is scoped by domain, so without one there is nothing to do.
    if let Some(domain_uuid) = domain_uuid {
        let removed_ids =
            db::nvl_partition::final_delete_nmx_m_only_for_logical_partition_and_domain(
                logical_partition_id,
                domain_uuid,
                txn,
            )
            .await?;

        let removed_total = removed_ids.len();
        if removed_total > 0 {
            tracing::info!(
                logical_partition_id = %logical_partition_id,
                domain_uuid = %domain_uuid,
                deleted_count = removed_total,
                "Deleted legacy NMX-M physical partition rows after NMX-C reconciliation"
            );
        }
    }

    Ok(())
}

async fn load_mnnvl_managed_host_snapshots(
&self,
txn: &mut sqlx::Transaction<'_, sqlx::Postgres>,
Expand Down
2 changes: 1 addition & 1 deletion crates/rpc/src/model/nvl_partition.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ impl TryFrom<NvlPartition> for rpc_forge::NvLinkPartition {
Ok(rpc_forge::NvLinkPartition {
id: Some(src.id),
name: src.name.clone().into(),
nmx_m_id: src.nmx_m_id,
nmx_m_id: src.nmx_m_id.unwrap_or_default(),
domain_uuid: Some(src.domain_uuid),
logical_partition_id: src.logical_partition_id,
})
Expand Down
Loading