Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
206 changes: 166 additions & 40 deletions crates/contract/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@ pub mod tee;
pub mod update;
#[cfg(feature = "dev-utils")]
pub mod utils;

pub mod v3_8_1_state;

#[cfg(feature = "bench-contract-methods")]
Expand Down Expand Up @@ -68,9 +67,9 @@ use primitives::{
key_state::{AuthenticatedAccountId, AuthenticatedParticipantId, EpochId, KeyEventId, Keyset},
signature::{SignRequest, SignRequestArgs, SignatureRequest, YieldIndex},
thresholds::{Threshold, ThresholdParameters},
votes::types::{ProposalHash, ProposalId},
};
use tee::measurements::{ContractExpectedMeasurements, MeasurementVoteAction, MeasurementVotes};
use tee::proposal::{CodeHashesVotes, LauncherHashVotes};
use tee::measurements::{ContractExpectedMeasurements, MeasurementVoteAction};

use state::{running::RunningContractState, ProtocolContractState};
use tee::{
Expand Down Expand Up @@ -1386,14 +1385,26 @@ impl MpcContract {
};

let participant = AuthenticatedParticipantId::new(threshold_parameters.participants())?;
let votes = self.tee_state.vote(code_hash, &participant);
let votes = self.tee_state.vote(code_hash, participant);

let tee_upgrade_deadline_duration =
Duration::from_secs(self.config.tee_upgrade_deadline_duration_seconds);

let num_votes = votes.count_for(|authenticated_participant_id| {
threshold_parameters
.participants()
.is_participant_given_participant_id(&authenticated_participant_id.get())
});

// If the vote threshold is met and the new Docker hash is allowed by the TEE's RTMR3,
// update the state
if votes >= self.threshold()?.value() {
if num_votes
>= self
.threshold()?
.value()
.try_into()
.expect("threshold must not exceed usize limit")
{
self.tee_state
.whitelist_tee_proposal(code_hash, tee_upgrade_deadline_duration);
}
Expand Down Expand Up @@ -1427,10 +1438,22 @@ impl MpcContract {
let action = LauncherVoteAction::Add(launcher_hash);
let votes = self.tee_state.vote_launcher(action, &participant);

let num_votes = votes.count_for(|authenticated_participant_id| {
threshold_parameters
.participants()
.is_participant_given_participant_id(&authenticated_participant_id.get())
});

let tee_upgrade_deadline_duration =
Duration::from_secs(self.config.tee_upgrade_deadline_duration_seconds);

if votes >= self.threshold()?.value() {
if num_votes
>= self
.threshold()?
.value()
.try_into()
.expect("expect cast from u64 to usize for threshold to succeed")
{
let added = self
.tee_state
.add_launcher_image(launcher_hash, tee_upgrade_deadline_duration);
Expand Down Expand Up @@ -1464,10 +1487,15 @@ impl MpcContract {
let participant = AuthenticatedParticipantId::new(threshold_parameters.participants())?;
let action = LauncherVoteAction::Remove(launcher_hash);
let votes = self.tee_state.vote_launcher(action, &participant);
let num_votes = votes.count_for(|authenticated_participant_id| {
threshold_parameters
.participants()
.is_participant_given_participant_id(&authenticated_participant_id.get())
});

// Removal requires ALL participants to vote
let total_participants = threshold_parameters.participants().len() as u64;
if votes >= total_participants {
let total_participants = threshold_parameters.participants().len();
if num_votes >= total_participants {
let removed = self.tee_state.remove_launcher_image(&launcher_hash);
log!("launcher hash remove result: {}", removed);
}
Expand Down Expand Up @@ -1497,9 +1525,21 @@ impl MpcContract {

let participant = AuthenticatedParticipantId::new(threshold_parameters.participants())?;
let action = MeasurementVoteAction::Add(measurement.clone());
let votes = self.tee_state.vote_measurement(action, &participant);
let votes = self.tee_state.vote_measurement(action, participant);

let num_votes = votes.count_for(|authenticated_participant_id| {
threshold_parameters
.participants()
.is_participant_given_participant_id(&authenticated_participant_id.get())
});

if votes >= self.threshold()?.value() {
if num_votes
>= self
.threshold()?
.value()
.try_into()
.expect("converting threshold to usize must succeed")
{
let added = self.tee_state.add_measurement(measurement);
log!("OS measurement add result: {}", added);
}
Expand Down Expand Up @@ -1530,11 +1570,16 @@ impl MpcContract {

let participant = AuthenticatedParticipantId::new(threshold_parameters.participants())?;
let action = MeasurementVoteAction::Remove(measurement.clone());
let votes = self.tee_state.vote_measurement(action, &participant);
let votes = self.tee_state.vote_measurement(action, participant);

let num_votes = votes.count_for(|authenticated_participant_id| {
threshold_parameters
.participants()
.is_participant_given_participant_id(&authenticated_participant_id.get())
});
// Removal requires ALL participants to vote
let total_participants = threshold_parameters.participants().len() as u64;
if votes >= total_participants {
let total_participants = threshold_parameters.participants().len();
if num_votes >= total_participants {
let removed = self.tee_state.remove_measurement(&measurement);
log!("OS measurement remove result: {}", removed);
}
Expand All @@ -1543,9 +1588,17 @@ impl MpcContract {
}

/// Returns the current OS measurement votes, showing each participant's vote.
pub fn os_measurement_votes(&self) -> MeasurementVotes {
pub fn os_measurement_votes(
&self,
) -> BTreeMap<
ProposalId,
(
(ProposalHash, MeasurementVoteAction),
BTreeSet<AuthenticatedParticipantId>,
),
> {
log!("os_measurement_votes");
self.tee_state.measurement_votes.clone()
self.tee_state.measurement_votes.snapshot()
}

/// Returns all currently allowed OS measurements.
Expand Down Expand Up @@ -1897,13 +1950,29 @@ impl MpcContract {
}

/// Returns the current launcher hash votes, showing each participant's vote.
pub fn launcher_hash_votes(&self) -> LauncherHashVotes {
self.tee_state.launcher_votes.clone()
pub fn launcher_hash_votes(
&self,
) -> BTreeMap<
ProposalId,
(
(ProposalHash, LauncherVoteAction),
BTreeSet<AuthenticatedParticipantId>,
),
> {
self.tee_state.launcher_votes.snapshot()
}

/// Returns the current code hash votes, showing each participant's vote.
pub fn code_hash_votes(&self) -> CodeHashesVotes {
self.tee_state.votes.clone()
pub fn code_hash_votes(
&self,
) -> BTreeMap<
ProposalId,
(
(ProposalHash, NodeImageHash),
BTreeSet<AuthenticatedParticipantId>,
),
> {
self.tee_state.votes.snapshot()
}

pub fn get_pending_request(&self, request: &SignatureRequest) -> Option<YieldIndex> {
Expand Down Expand Up @@ -2378,7 +2447,6 @@ mod tests {
};

use super::*;
use crate::errors::NodeMigrationError;
use crate::primitives::participants::{ParticipantId, ParticipantInfo};
use crate::primitives::test_utils::{
bogus_ed25519_near_public_key, bogus_ed25519_public_key, gen_account_id, gen_participant,
Expand All @@ -2400,6 +2468,7 @@ mod tests {
};
use crate::tee::proposal::{get_docker_compose_hash, LauncherVoteAction};
use crate::tee::tee_state::NodeId;
use crate::{errors::NodeMigrationError, primitives::votes::types::PROPOSAL_HASH_BYTES};
use assert_matches::assert_matches;
use dtos::{Attestation, Ed25519PublicKey, ForeignTxSignPayload, MockAttestation};
use elliptic_curve::Field as _;
Expand Down Expand Up @@ -5551,10 +5620,10 @@ mod tests {
let participant_list = participants.participants();
let launcher_hash = make_launcher_hash(0xCC);

assert!(contract.launcher_hash_votes().vote_by_account.is_empty());
assert!(contract.launcher_hash_votes().is_empty());

// First vote
let (account_0, _, _) = &participant_list[0];
let (account_0, p_id_0, _) = &participant_list[0];
testing_env!(VMContextBuilder::new()
.signer_account_id(account_0.clone())
.predecessor_account_id(account_0.clone())
Expand All @@ -5563,13 +5632,29 @@ mod tests {
.vote_add_launcher_hash(launcher_hash)
.expect("vote should succeed");

let votes = &contract.launcher_hash_votes().vote_by_account;
assert_eq!(votes.len(), 1);
let auth_p_id_0 = AuthenticatedParticipantId::new(&participants).unwrap();
assert_eq!(auth_p_id_0.get(), *p_id_0);

let expected_action = LauncherVoteAction::Add(launcher_hash);
assert!(votes.values().all(|v| *v == expected_action));
let expected_hash: [u8; PROPOSAL_HASH_BYTES] =
sha2::Sha256::digest(borsh::to_vec(&expected_action).unwrap()).into();
let expected_hash = ProposalHash::new(expected_hash);
{
let votes = contract.launcher_hash_votes();
assert_eq!(
votes,
BTreeMap::from([(
0.into(),
(
(expected_hash, expected_action.clone()),
BTreeSet::from([auth_p_id_0.clone()])
)
)])
);
}

// Second vote
let (account_1, _, _) = &participant_list[1];
let (account_1, p_id_1, _) = &participant_list[1];
testing_env!(VMContextBuilder::new()
.signer_account_id(account_1.clone())
.predecessor_account_id(account_1.clone())
Expand All @@ -5578,9 +5663,21 @@ mod tests {
.vote_add_launcher_hash(launcher_hash)
.expect("vote should succeed");

let votes = &contract.launcher_hash_votes().vote_by_account;
assert_eq!(votes.len(), 2);
assert!(votes.values().all(|v| *v == expected_action));
let auth_p_id_1 = AuthenticatedParticipantId::new(&participants).unwrap();
assert_eq!(auth_p_id_1.get(), *p_id_1);
{
let votes = contract.launcher_hash_votes();
assert_eq!(
votes,
BTreeMap::from([(
0.into(),
(
(expected_hash, expected_action.clone()),
BTreeSet::from([auth_p_id_0, auth_p_id_1])
)
)])
);
}

// Third vote reaches threshold — votes should be cleared
let (account_2, _, _) = &participant_list[2];
Expand All @@ -5593,7 +5690,7 @@ mod tests {
.expect("vote should succeed");

assert!(
contract.launcher_hash_votes().vote_by_account.is_empty(),
contract.launcher_hash_votes().is_empty(),
"votes should be cleared after threshold reached"
);
}
Expand All @@ -5610,8 +5707,12 @@ mod tests {
let participant_list = participants.participants();
let code_hash = NodeImageHash::from([0xAB; 32]);

assert!(contract.code_hash_votes().proposal_by_account.is_empty());
assert!(contract.code_hash_votes().is_empty());

let expected_hash: [u8; PROPOSAL_HASH_BYTES] =
sha2::Sha256::digest(borsh::to_vec(&code_hash).unwrap()).into();
let expected_hash = ProposalHash::new(expected_hash);
let mut expected_voter_set = BTreeSet::new();
for (i, (account, _, _)) in participant_list[..threshold as usize].iter().enumerate() {
testing_env!(VMContextBuilder::new()
.signer_account_id(account.clone())
Expand All @@ -5621,10 +5722,19 @@ mod tests {
.vote_code_hash(code_hash)
.expect("vote should succeed");

let votes = &contract.code_hash_votes().proposal_by_account;
let auth_p_id = AuthenticatedParticipantId::new(&participants).unwrap();
expected_voter_set.insert(auth_p_id.clone());
let votes = contract.code_hash_votes();

if i < (threshold - 1) as usize {
assert_eq!(votes.len(), i + 1);
assert!(votes.values().all(|v| *v == code_hash));
let expected = BTreeMap::from([(
0.into(),
(
(expected_hash, code_hash.clone()),
expected_voter_set.clone(),
),
)]);
assert_eq!(votes, expected);
} else {
assert!(
votes.is_empty(),
Expand Down Expand Up @@ -5940,10 +6050,10 @@ mod tests {
let measurement = make_measurement(0xCC);

// Initially empty
assert!(contract.os_measurement_votes().vote_by_account.is_empty());
assert!(contract.os_measurement_votes().is_empty());

// Cast one vote
let (account_id, _, _) = &participant_list[0];
let (account_id, p_id, _) = &participant_list[0];
testing_env!(VMContextBuilder::new()
.signer_account_id(account_id.clone())
.predecessor_account_id(account_id.clone())
Expand All @@ -5952,10 +6062,26 @@ mod tests {
.vote_add_os_measurement(measurement.clone())
.expect("add vote should succeed");

let votes = contract.os_measurement_votes();
assert_eq!(votes.vote_by_account.len(), 1);
let (_, action) = votes.vote_by_account.iter().next().unwrap();
assert_eq!(*action, MeasurementVoteAction::Add(measurement));
let expected_action = MeasurementVoteAction::Add(measurement);
let expected_hash: [u8; PROPOSAL_HASH_BYTES] =
sha2::Sha256::digest(borsh::to_vec(&expected_action).unwrap()).into();
let expected_hash = ProposalHash::new(expected_hash);

let auth_p_id = AuthenticatedParticipantId::new(&participants).unwrap();
assert_eq!(auth_p_id.get(), *p_id);
{
let votes = contract.os_measurement_votes();
assert_eq!(
votes,
BTreeMap::from([(
0.into(),
(
(expected_hash, expected_action),
BTreeSet::from([auth_p_id])
)
)])
);
}
}

/// Tests the allowed_os_measurements view method returns the full structs
Expand Down
2 changes: 1 addition & 1 deletion crates/contract/src/primitives/votes/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ pub trait VoterBounds: BorshSerialize + BorshDeserialize + Ord + Clone {}
impl<T: BorshSerialize + BorshDeserialize + Ord + Clone> ProposalBounds for T {}
impl<T: BorshSerialize + BorshDeserialize + Ord + Clone> VoterBounds for T {}

#[near(serializers=[borsh])]
#[near(serializers=[borsh, json])]
#[derive(Debug, PartialEq, PartialOrd, Eq, Ord, Clone, Copy, From, Deref, Into)]
pub struct ProposalId(pub(crate) u64);

Expand Down
Loading
Loading