Source code for kaplan_meier.data_owner

"""
Module implements the data owners Alice and Bob.
"""

from __future__ import annotations

import asyncio
from typing import Any, Optional, SupportsInt, Union, cast

import numpy as np
import numpy.typing as npt
import pandas as pd

from tno.mpc.communication import Pool
from tno.mpc.encryption_schemes.paillier import Paillier, PaillierCiphertext

from .player import Player


[docs]class DataOwner(Player): """ Data owner in the MPC protocol """
[docs] def __init__( self, data: pd.DataFrame, pool: Pool, *args: Any, **kwargs: Any, ) -> None: """ Initializes data owner :param data: the data to use for this data owner :param pool: a communication pool :param args: arguments to pass on to base class :param kwargs: keyword arguments to pass on to base class """ super().__init__(*args, **kwargs) self._paillier_scheme: Optional[Paillier] = None self._data = data self.pool = pool
[docs] async def receive_message(self, party: str, msg_id: Optional[str] = None) -> Any: """ Receives a message from a party (belonging to an optional message identifier) :param party: the party to receive a message from :param msg_id: the message id :return: the received message """ return await self.pool.recv(party, msg_id=msg_id)
[docs] async def send_message( self, receiver: str, message: Any, msg_id: Optional[str] = None ) -> None: """ Sends a message to a party (with an optional message identifier) :param receiver: the party to send a message to :param message: the message to send :param msg_id: the message id """ await self.pool.send(receiver, message, msg_id=msg_id)
@property def records(self) -> int: """ Number of records in the loaded dataset :return: number of records """ return self.data.shape[0] # type: ignore[no-any-return] @property def groups(self) -> int: """ Number of groups in the loaded datasets :return: number of groups :raise NotImplementedError: raised when not implemented """ raise NotImplementedError() @property def data(self) -> Union[pd.DataFrame, npt.NDArray[np.int32]]: """ The loaded dataset :return: dataset :raise ValueError: raised when there is no data available """ if self._data is None: raise ValueError("No event data available yet.") return self._data @property def paillier_scheme(self) -> Paillier: """ The Paillier scheme :return: Paillier scheme :raise ValueError: raised when Paillier scheme is not available yet. """ if self._paillier_scheme is None: raise ValueError("There is no Paillier scheme available yet.") return self._paillier_scheme
[docs] def stop_randomness_generation(self) -> None: """ Stop generation of randomness. """ self.paillier_scheme.randomness.shut_down()
[docs] def encrypt(self, data: npt.NDArray[np.float64]) -> npt.NDArray[np.object_]: """ Method to encrypt a dataset using the initialized Paillier scheme :param data: the dataset to encrypt :return: an encrypted dataset """ self._logger.info("Encrypting data...") encrypted_data: npt.NDArray[np.object_] = np.vectorize( self.paillier_scheme.encrypt )(data) self._logger.info("Done encrypting data") return encrypted_data
[docs] def decrypt(self, data: npt.NDArray[np.object_]) -> npt.NDArray[np.float64]: """ Method to decrypt a dataset using the initialized Paillier scheme :param data: the dataset to decrypt :return: a decrypted dataset """ self._logger.info("Decrypting data...") decrypted_data: npt.NDArray[np.float64] = np.vectorize( self.paillier_scheme.decrypt )(data) self._logger.info("Done decrypting data") return decrypted_data
[docs]class Alice(DataOwner): """ Alice player in the MPC protocol """
[docs] def __init__( self, *args: Any, nr_of_threads: int = 4, time_label: str = "time", event_label: str = "event", **kwargs: Any, ) -> None: """ Initializes player Alice :param nr_of_threads: the number of threads to use for randomness generation :param time_label: the label used to represent the 'time' column in the data set :param event_label: the label used to represent the 'event' column in the data set :param args: arguments to pass on to base class :param kwargs: keyword arguments to pass on to base class """ super().__init__(*args, **kwargs) self.nr_of_threads = nr_of_threads self.time_label = time_label self.event_label = event_label self._encrypted_group_data_: Optional[npt.NDArray[np.object_]] = None self._hidden_table: Optional[npt.NDArray[np.object_]] = None self._plain_table: Optional[npt.NDArray[np.int32]] = None self._indices_events: pd.Series = pd.Series(None, dtype=bool) self._mask_ht = None self._number_of_groups = None
@property def groups(self) -> int: """ Number of groups in the datasets :return: number of groups :raise ValueError: raised when number of groups is not available (yet) """ if self._number_of_groups is None: raise ValueError("Number of groups is not available yet") return self._number_of_groups @property def rows_in_hidden_table(self) -> int: """ Number of rows in the hidden table. Equals number of unique event times (ignoring censorings). :return: number of rows in the hidden table """ return self.data[self.time_label].loc[self.data[self.event_label].astype(bool)].nunique() # type: ignore[no-any-return] @property def cols_in_hidden_table(self) -> int: """ Number of columns in the hidden table. Two columns per group; one for the number of events on a given event time and one for the number of people at risk at that time. :return: number of columns in the hidden table """ return 2 * self.groups @property def _encrypted_group_data(self) -> npt.NDArray[np.object_]: """ Encrypted group data :return: the encrypted group data :raise ValueError: raised when the encrypted group data is not yet available. """ if self._encrypted_group_data_ is None: raise ValueError("Alice is missing some important data.") return self._encrypted_group_data_ @_encrypted_group_data.setter def _encrypted_group_data(self, data: npt.NDArray[np.object_]) -> None: self._encrypted_group_data_ = data @property def hidden_table(self) -> npt.NDArray[np.object_]: """ Hidden table :return: the constructed hidden table :raise ValueError: raised when hidden table is not yet available. """ if self._hidden_table is None: raise ValueError("Hidden table is not set yet.") return self._hidden_table @property def plain_table( self, ) -> npt.NDArray[np.int32]: """ Plain table :return: plaintext result of some computation in table format :raise ValueError: raised when plain table is not yet available. """ if self._plain_table is None: raise ValueError("Plain table is not set yet.") return self._plain_table
[docs] async def start_protocol(self) -> None: """ Starts and runs the protocol """ await asyncio.gather( *[ self.receive_paillier_scheme(), self.receive_number_of_groups(), ] ) self.start_randomness_generation() await self.receive_encrypted_group_data() self.compute_hidden_table() self.compute_factors() self.re_randomise_ht() self.stop_randomness_generation() self.generate_share() await self.send_share() await self.run_mpyc()
[docs] def start_randomness_generation(self) -> None: """ Kicks off the randomness generation. This boosts performance. In particular will this decrease the total runtime (as data owners can already generate randomness before they need it). """ # Total required randomness: # - once for re-randomizing the hidden table # - once for making an additive mask of the hidden table self.paillier_scheme.initialize_randomness( nr_of_threads=self.nr_of_threads, start_generation=True, max_size=2 * self.cols_in_hidden_table * self.rows_in_hidden_table, total=2 * self.cols_in_hidden_table * self.rows_in_hidden_table, )
[docs] async def receive_paillier_scheme(self) -> None: """ Method to receive the Paillier scheme that is used by party Bob. """ self._paillier_scheme = await self.receive_message( self.party_B, msg_id="paillier_scheme" )
[docs] async def receive_number_of_groups(self) -> None: """ Method to receive the number of groups identified by party Bob. """ self._number_of_groups = await self.receive_message( self.party_B, msg_id="number_of_groups" )
[docs] async def receive_encrypted_group_data(self) -> None: """ Method to receive the encrypted group data from party Bob. """ self._encrypted_group_data = await self.receive_message( self.party_B, msg_id="encrypted_group_data" )
[docs] def compute_hidden_table(self) -> None: """ Method to compute the hidden table of the protocol. """ self._logger.info("Computing Kaplan-Meier features from encrypted data...") self._sort_data() self._determine_events() self._remove_censored_and_duplicates() self._logger.info("Done computing Kaplan-Meier features from encrypted data")
def _sort_data(self) -> None: """ Sort data by time (ascending), then by event (descending). To obtain correct result, first sort all data by event, then sort all data by time. :raise AttributeError: raised when data is not a pandas dataframe """ if not isinstance(self.data, pd.DataFrame): raise AttributeError("Data is not a pandas dataframe") self._data = self.data.sort_values( by=[self.time_label, self.event_label], ascending=[True, False] ) self._encrypted_group_data = self._encrypted_group_data[ cast(slice, self.data.index) ] def _determine_events(self) -> None: """ Determine the indices of the events. """ self._indices_events = self.data[self.event_label] == 1 def _remove_censored_and_duplicates(self) -> None: """ Removes censored data and processes duplicates. :raise ValueError: raised when event indices are not determined """ diff: npt.NDArray[np.int32] = np.diff(self.data[self._indices_events][self.time_label]) # type: ignore[no-untyped-call] add = (np.nonzero(diff)[0] + 1).astype(np.int32) # type: ignore[attr-defined] add = np.insert(add, 0, np.int32(0)) # type: ignore[no-untyped-call] grouped_data = np.c_[ np.ones((self.data.shape[0], 1)), self._encrypted_group_data ] # Compute the result columns # Summing exposed_cols = grouped_data[::-1].cumsum(axis=0)[::-1] type_cols = np.add.reduceat(grouped_data[self._indices_events], add) # Removing if self._indices_events is None: raise ValueError("Indices of events are not determined (yet).") exposed_cols = exposed_cols[ self._indices_events & ~self.data["time"].duplicated(keep="first") ] self._hidden_table = np.c_[type_cols[:, 1:], exposed_cols[:, 1:]] self._plain_table = np.c_[type_cols[:, 0], exposed_cols[:, 0]]
[docs] def compute_factors(self) -> None: """Pre-computes several factors for in the computation of the log- rank statistic, leveraging information known by Alice only. Computes the following factors: dev_factors, var_factors, var_factors_2. These factors satisfy the following relations: Expected number of deaths in group i = dev_factors[i] * at_risk_group[i] Variance of deaths in group i = (var_factors_2[i] - var_factors[i] * at_risk_group[i]) * at_risk_group[i] """ at_risk_total = self.plain_table[:, 1] deaths_total = self.plain_table[:, 0] # Expected number of deaths(E) = # (deaths_total / at_risk_total) * # [at_risk_group] dev_factors = deaths_total / at_risk_total # Variance = # deaths_total * (at_risk_total - deaths_total) / # (at_risk_total**2 * (at_risk_total - 1)) * # [at_risk_group] * (at_risk_total - [at_risk_group]) # Note here that the denominator equals zero if at_risk_total # equals one, which is only possible in the last event time. # The variance should then also equal zero. Since # deaths_total is always strictly positive, we find that # necessarily at_risk_total - deaths_total = 0 if # at_risk_total = 0. Therefore, the following produces the # correct variance for every event time without dividing by # zero. var_factors = deaths_total * (at_risk_total - deaths_total) / at_risk_total ** 2 var_ind = at_risk_total != 1 var_factors[var_ind] /= at_risk_total[var_ind] - 1 var_factors_2 = var_factors * at_risk_total self._mpyc_factors = np.c_[dev_factors, var_factors, var_factors_2]
[docs] def generate_share(self) -> None: """ Generates additive secret shares. """ self._mpyc_data: npt.NDArray[np.float64] = np.vectorize( lambda _: self.signed_randomness() )(np.ndarray(self.hidden_table.shape)) self._logger.info("Generated share")
[docs] def mask_hidden_table( self, ) -> npt.NDArray[np.float64]: """ Masks the hidden table. :return: a masked hidden table """ return self.hidden_table - self.share
[docs] async def send_share(self) -> None: """ Sends additive secret share to party Bob. """ loop = asyncio.get_event_loop() masked_hidden_table = await loop.run_in_executor(None, self.mask_hidden_table) await self.send_message(self.party_B, masked_hidden_table, msg_id="share") self._logger.info("Sent share")
[docs] def signed_randomness(self) -> SupportsInt: """ Returns a signed random plaintext value. :return: signed random plaintext value """ return self.paillier_scheme.random_plaintext()
[docs] @staticmethod def re_randomize(ciphertext: PaillierCiphertext) -> None: """ Re-randomises a ciphertext :param ciphertext: ciphertext to randomize """ ciphertext.randomize()
[docs] def re_randomise_ht(self) -> None: """ Re-randomises the hidden table """ np.vectorize(self.re_randomize)(self.hidden_table)
[docs]class Bob(DataOwner): """ Bob player in the MPC protocol """
[docs] def __init__( self, *args: Any, paillier_scheme: Paillier = Paillier.from_security_parameter( key_length=2048, precision=0 ), **kwargs: Any, ) -> None: """ Initializes player Bob :param paillier_scheme: the Paillier scheme to use for encryption :param args: arguments to pass on to base class :param kwargs: keyword arguments to pass on to base class """ super().__init__(*args, **kwargs) self._paillier_scheme = paillier_scheme self.encrypted_data: Optional[npt.NDArray[np.object_]] = None self._hidden_table: Optional[npt.NDArray[np.object_]] = None
@property def groups(self) -> int: """ Number of groups in the loaded dataset :return: number of groups """ return self.data.shape[1] # type: ignore[no-any-return]
[docs] async def start_protocol(self) -> None: """ Starts and runs the protocol """ await self.send_number_of_groups() loop = asyncio.get_event_loop() _, _, self.encrypted_data = await asyncio.gather( self.send_paillier_scheme(), self.send_number_of_groups(), loop.run_in_executor(None, self.encrypt, self.data), ) self.stop_randomness_generation() await self.send_encrypted_data() await self.receive_share() await self.run_mpyc()
[docs] async def send_paillier_scheme(self) -> None: """ Sends the used Paillier scheme to party Alice. """ await self.send_message( self.party_A, self.paillier_scheme, msg_id="paillier_scheme" )
[docs] async def send_number_of_groups(self) -> None: """ Sends the number of groups to party Alice. """ await self.send_message(self.party_A, self.groups, msg_id="number_of_groups")
[docs] async def send_encrypted_data(self) -> None: """ Sends the encrypted dataset to party Alice. """ await self.send_message( self.party_A, self.encrypted_data, msg_id="encrypted_group_data" )
[docs] async def receive_share(self) -> None: """ Receive additive secret share produced by party Alice. """ encrypted_share = await self.receive_message(self.party_A, msg_id="share") self._mpyc_data = await self.decrypt_share(encrypted_share) self._mpyc_factors = np.zeros((len(self._mpyc_data), 3), dtype=np.float64)
[docs] async def decrypt_share( self, data: npt.NDArray[np.object_] ) -> npt.NDArray[np.float64]: """ Decrypt share :param data: the dataset (share) to decrypt :return: decrypted data set """ loop = asyncio.get_event_loop() return await loop.run_in_executor(None, self.decrypt, data)