Source code for kaplan_meier.data_owner

"""
Module implements the data owners Alice and Bob.
"""

from __future__ import annotations

import asyncio
from typing import Any, Optional, SupportsInt, Union, cast

import numpy as np
import pandas as pd

from tno.mpc.communication import Pool
from tno.mpc.encryption_schemes.paillier import Paillier, PaillierCiphertext
from tno.mpc.encryption_schemes.utils.fixed_point import FixedPoint

from .player import Player


class DataOwner(Player):
    """
    Data owner in the MPC protocol.

    Holds a dataset, a communication pool and (once available) a Paillier
    scheme, and offers helpers to exchange messages and to encrypt or decrypt
    arrays element-wise.
    """

    def __init__(
        self,
        data: pd.DataFrame,
        pool: Pool,
        *args: Any,
        **kwargs: Any,
    ) -> None:
        """
        Initializes data owner

        :param data: the data to use for this data owner
        :param pool: a communication pool
        :param args: arguments to pass on to base class
        :param kwargs: keyword arguments to pass on to base class
        """
        super().__init__(*args, **kwargs)
        # Subclasses assign the scheme later (received or passed in).
        self._paillier_scheme: Optional[Paillier] = None
        self._data = data
        self.pool = pool

    async def receive_message(self, party: str, msg_id: Optional[str] = None) -> Any:
        """
        Receives a message from a party (belonging to an optional message
        identifier)

        :param party: the party to receive a message from
        :param msg_id: the message id
        :return: the received message
        """
        return await self.pool.recv(party, msg_id=msg_id)

    async def send_message(
        self, receiver: str, message: Any, msg_id: Optional[str] = None
    ) -> None:
        """
        Sends a message to a party (with an optional message identifier)

        :param receiver: the party to send a message to
        :param message: the message to send
        :param msg_id: the message id
        """
        await self.pool.send(receiver, message, msg_id=msg_id)

    @property
    def records(self) -> int:
        """
        Number of records in the loaded dataset

        :return: number of records
        """
        return self.data.shape[0]

    @property
    def groups(self) -> int:
        """
        Number of groups in the loaded datasets

        :return: number of groups
        :raise NotImplementedError: raised when not implemented
        """
        raise NotImplementedError()

    @property
    def data(self) -> Union[pd.DataFrame, np.ndarray[np.int32]]:
        """
        The loaded dataset

        :return: dataset
        :raise ValueError: raised when there is no data available
        """
        if self._data is None:
            raise ValueError("No event data available yet.")
        return self._data

    @property
    def paillier_scheme(self) -> Paillier:
        """
        The Paillier scheme

        :return: Paillier scheme
        :raise ValueError: raised when Paillier scheme is not available yet.
        """
        if self._paillier_scheme is None:
            raise ValueError("There is no Paillier scheme available yet.")
        return self._paillier_scheme

    def stop_randomness_generation(self) -> None:
        """
        Stop generation of randomness.
        """
        self.paillier_scheme.randomness.shut_down()

    def encrypt(self, data: np.ndarray[np.float64]) -> np.ndarray[PaillierCiphertext]:  # type: ignore[type-var]
        """
        Method to encrypt a dataset using the initialized Paillier scheme

        :param data: the dataset to encrypt
        :return: an encrypted dataset
        """
        self._logger.info("Encrypting data...")
        # Without `otypes`, np.vectorize probes the function on the first
        # element to infer the output type and then evaluates that element
        # again. cache=True reuses the probe result, so every value is
        # encrypted exactly once while the inferred dtype stays the same.
        elementwise_encrypt = np.vectorize(self.paillier_scheme.encrypt, cache=True)
        encrypted_data: np.ndarray[PaillierCiphertext] = elementwise_encrypt(data)  # type: ignore[attr-defined, type-var]
        self._logger.info("Done encrypting data")
        return encrypted_data

    def decrypt(self, data: np.ndarray[PaillierCiphertext]) -> np.ndarray[Any]:  # type: ignore[type-var]
        """
        Method to decrypt a dataset using the initialized Paillier scheme

        :param data: the dataset to decrypt
        :return: a decrypted dataset
        """
        self._logger.info("Decrypting data...")
        # cache=True avoids decrypting the first element twice (see encrypt).
        elementwise_decrypt = np.vectorize(self.paillier_scheme.decrypt, cache=True)
        decrypted_data: np.ndarray[Any] = elementwise_decrypt(data)  # type: ignore[attr-defined]
        self._logger.info("Done decrypting data")
        return decrypted_data
class Alice(DataOwner):
    """
    Alice player in the MPC protocol.

    Alice holds the plaintext time/event data. She receives Bob's Paillier
    scheme and encrypted group data, builds the hidden (encrypted) table,
    masks it with a random additive share and sends the masked table to Bob.
    """

    def __init__(self, *args: Any, nr_of_threads: int = 4, **kwargs: Any) -> None:
        """
        Initializes player Alice

        :param nr_of_threads: the number of threads to use for randomness
            generation
        :param args: arguments to pass on to base class
        :param kwargs: keyword arguments to pass on to base class
        """
        super().__init__(*args, **kwargs)
        self.nr_of_threads = nr_of_threads
        self._encrypted_group_data_: Optional[np.ndarray[PaillierCiphertext]] = None  # type: ignore[type-var]  # pylint: disable=unsubscriptable-object
        self._hidden_table: Optional[np.ndarray[PaillierCiphertext]] = None  # type: ignore[type-var]  # pylint: disable=unsubscriptable-object
        self._plain_table: Optional[
            np.ndarray[np.int32]  # pylint: disable=unsubscriptable-object
        ] = None
        # Starts out as an empty boolean series; filled by _determine_events.
        self._indices_events: pd.Series[bool] = pd.Series(None, dtype=bool)  # type: ignore[call-arg, type-var]  # pylint: disable=unsubscriptable-object
        self._mask_ht = None
        self._number_of_groups = None

    @property
    def groups(self) -> int:
        """
        Number of groups in the datasets

        :return: number of groups
        :raise ValueError: raised when number of groups is not available (yet)
        """
        if self._number_of_groups is None:
            raise ValueError("Number of groups is not available yet")
        return self._number_of_groups

    @property
    def rows_in_hidden_table(self) -> int:
        """
        Number of rows in the hidden table. Equals number of unique event
        times (ignoring censorings).

        :return: number of rows in the hidden table
        """
        is_event = self.data["event"].astype(bool)
        return self.data["time"].loc[is_event].nunique()

    @property
    def cols_in_hidden_table(self) -> int:
        """
        Number of columns in the hidden table. Two columns per group; one for
        the number of events on a given event time and one for the number of
        people at risk at that time.

        :return: number of columns in the hidden table
        """
        return 2 * self.groups

    @property
    def _encrypted_group_data(self) -> np.ndarray[PaillierCiphertext]:  # type: ignore[type-var]  # pylint: disable=unsubscriptable-object
        """
        Encrypted group data

        :return: the encrypted group data
        :raise ValueError: raised when the encrypted group data is not yet
            available.
        """
        if self._encrypted_group_data_ is None:
            raise ValueError("Alice is missing some important data.")
        return self._encrypted_group_data_

    @_encrypted_group_data.setter
    def _encrypted_group_data(self, data: np.ndarray[PaillierCiphertext]) -> None:  # type: ignore[type-var]  # pylint: disable=unsubscriptable-object
        self._encrypted_group_data_ = data

    @property
    def hidden_table(self) -> np.ndarray[PaillierCiphertext]:  # type: ignore[type-var]  # pylint: disable=unsubscriptable-object
        """
        Hidden table

        :return: the constructed hidden table
        :raise ValueError: raised when hidden table is not yet available.
        """
        if self._hidden_table is None:
            raise ValueError("Hidden table is not set yet.")
        return self._hidden_table

    @property
    def plain_table(
        self,
    ) -> np.ndarray[np.int32]:  # pylint: disable=unsubscriptable-object
        """
        Plain table

        :return: plaintext result of some computation in table format
        :raise ValueError: raised when plain table is not yet available.
        """
        if self._plain_table is None:
            raise ValueError("Plain table is not set yet.")
        return self._plain_table

    async def start_protocol(self) -> None:
        """
        Starts and runs the protocol
        """
        # The scheme and the number of groups are independent messages, so
        # wait for both concurrently.
        await asyncio.gather(
            self.receive_paillier_scheme(),
            self.receive_number_of_groups(),
        )
        self.start_randomness_generation()
        await self.receive_encrypted_group_data()
        self.compute_hidden_table()
        self.compute_factors()
        self.re_randomise_ht()
        self.stop_randomness_generation()
        self.generate_share()
        await self.send_share()
        await self.run_mpyc()

    def start_randomness_generation(self) -> None:
        """
        Kicks off the randomness generation. This boosts performance.
        In particular will this decrease the total runtime (as data owners can
        already generate randomness before they need it).
        """
        # Total required randomness:
        # - once for re-randomizing the hidden table
        # - once for making an additive mask of the hidden table
        required = 2 * self.cols_in_hidden_table * self.rows_in_hidden_table
        self.paillier_scheme.initialize_randomness(
            nr_of_threads=self.nr_of_threads,
            start_generation=True,
            max_size=required,
            total=required,
        )

    async def receive_paillier_scheme(self) -> None:
        """
        Method to receive the Paillier scheme that is used by party Bob.
        """
        self._paillier_scheme = await self.receive_message(
            self.party_B, msg_id="paillier_scheme"
        )

    async def receive_number_of_groups(self) -> None:
        """
        Method to receive the number of groups identified by party Bob.
        """
        self._number_of_groups = await self.receive_message(
            self.party_B, msg_id="number_of_groups"
        )

    async def receive_encrypted_group_data(self) -> None:
        """
        Method to receive the encrypted group data from party Bob.
        """
        self._encrypted_group_data = await self.receive_message(
            self.party_B, msg_id="encrypted_group_data"
        )

    def compute_hidden_table(self) -> None:
        """
        Method to compute the hidden table of the protocol.
        """
        self._logger.info("Computing Kaplan-Meier features from encrypted data...")
        self._sort_data()
        self._determine_events()
        self._remove_censored_and_duplicates()
        self._logger.info("Done computing Kaplan-Meier features from encrypted data")

    def _sort_data(self) -> None:
        """
        Sort data by time (ascending), then by event (descending), and apply
        the same row order to the encrypted group data.

        :raise AttributeError: raised when data is not a pandas dataframe
        """
        if not isinstance(self.data, pd.DataFrame):
            raise AttributeError("Data is not a pandas dataframe")
        self._data = self.data.sort_values(
            by=["time", "event"], ascending=[True, False]
        )
        # The sorted index labels are used as positions into the encrypted
        # array. NOTE(review): this presumes the dataframe carried a default
        # 0..n-1 index before sorting - confirm against callers.
        self._encrypted_group_data = self._encrypted_group_data[
            cast(slice, self.data.index)
        ]

    def _determine_events(self) -> None:
        """
        Determine the indices of the events.
        """
        self._indices_events = cast("pd.Series[bool]", self.data["event"] == 1)

    def _remove_censored_and_duplicates(self) -> None:
        """
        Removes censored data and processes duplicates.

        Builds the hidden table (encrypted per-group columns) and the plain
        table (totals over all records) with one row per distinct event time.

        :raise ValueError: raised when event indices are not determined
        """
        # Validate before the first use of the event mask (the check used to
        # sit after the mask had already been applied twice).
        if self._indices_events is None:
            raise ValueError("Indices of events are not determined (yet).")

        # Row offsets (within the event-only rows) at which a new event time
        # starts; offset 0 always starts the first segment.
        diff: np.ndarray[np.int32] = np.diff(self.data[self._indices_events]["time"])  # type: ignore[arg-type, call-overload]  # pylint: disable=unsubscriptable-object
        add = (np.nonzero(diff)[0] + 1).astype(np.int32)
        add = np.insert(add, 0, np.int32(0))

        # Prepend a column of ones so the same sums also yield plain totals.
        grouped_data = np.c_[  # type: ignore[attr-defined]
            np.ones((self.data.shape[0], 1)), self._encrypted_group_data
        ]

        # Compute the result columns
        # Summing: a reversed cumulative sum counts this row and all later
        # rows; reduceat sums the event rows per distinct event time.
        exposed_cols = grouped_data[::-1].cumsum(axis=0)[::-1]
        type_cols = np.add.reduceat(grouped_data[self._indices_events], add)  # type: ignore[attr-defined]

        # Removing: keep only the first row of every time that is an event.
        exposed_cols = exposed_cols[
            self._indices_events & ~self.data["time"].duplicated(keep="first")
        ]
        self._hidden_table = np.c_[type_cols[:, 1:], exposed_cols[:, 1:]]  # type: ignore[attr-defined]
        self._plain_table = np.c_[type_cols[:, 0], exposed_cols[:, 0]]  # type: ignore[attr-defined]

    def compute_factors(self) -> None:
        """
        Pre-computes several factors for in the computation of the log-rank
        statistic, leveraging information known by Alice only.

        Computes the following factors: dev_factors, var_factors,
        var_factors_2. These factors satisfy the following relations:

        Expected number of deaths in group i =
            dev_factors[i] * at_risk_group[i]
        Variance of deaths in group i =
            (var_factors_2[i] - var_factors[i] * at_risk_group[i])
            * at_risk_group[i]
        """
        at_risk_total = self.plain_table[:, 1]
        deaths_total = self.plain_table[:, 0]

        # Expected number of deaths(E) =
        #   (deaths_total / at_risk_total) * [at_risk_group]
        dev_factors = deaths_total / at_risk_total

        # Variance =
        #   deaths_total * (at_risk_total - deaths_total) /
        #   (at_risk_total**2 * (at_risk_total - 1)) *
        #   [at_risk_group] * (at_risk_total - [at_risk_group])
        # The denominator is zero when at_risk_total equals one, which is
        # only possible at the last event time. The variance should then be
        # zero as well: deaths_total is always strictly positive, so
        # at_risk_total = 1 forces at_risk_total - deaths_total = 0 and the
        # numerator already vanishes. Hence the division by
        # (at_risk_total - 1) is simply skipped for those rows.
        var_factors = deaths_total * (at_risk_total - deaths_total) / at_risk_total ** 2
        var_ind = at_risk_total != 1
        var_factors[var_ind] /= at_risk_total[var_ind] - 1
        var_factors_2 = var_factors * at_risk_total

        self._mpyc_factors = np.c_[dev_factors, var_factors, var_factors_2]  # type: ignore[attr-defined]

    def generate_share(self) -> None:
        """
        Generates additive secret shares.
        """
        # One random plaintext per hidden-table cell. The placeholder array
        # only provides the shape; its contents are ignored. cache=True stops
        # np.vectorize from drawing an extra value for the first cell.
        placeholder = np.empty(self.hidden_table.shape)
        draw = np.vectorize(lambda _: self.signed_randomness(), cache=True)
        self._mpyc_data: np.ndarray[FixedPoint] = draw(placeholder)  # type: ignore[attr-defined, type-var]  # pylint: disable=unsubscriptable-object
        self._logger.info("Generated share")

    def mask_hidden_table(
        self,
    ) -> np.ndarray[np.float64]:  # pylint: disable=unsubscriptable-object
        """
        Masks the hidden table.

        :return: a masked hidden table
        """
        # `share` is not defined in this module; presumably provided by the
        # base class - verify there.
        return cast("np.ndarray[np.float64]", self.hidden_table - self.share)

    async def send_share(self) -> None:
        """
        Sends additive secret share to party Bob.
        """
        # Masking runs in the default executor so the event loop is not
        # blocked while it is computed.
        loop = asyncio.get_running_loop()
        masked_hidden_table = await loop.run_in_executor(None, self.mask_hidden_table)
        await self.send_message(self.party_B, masked_hidden_table, msg_id="share")
        self._logger.info("Sent share")

    def signed_randomness(self) -> SupportsInt:
        """
        Returns a signed random plaintext value.

        :return: signed random plaintext value
        """
        return self.paillier_scheme.random_plaintext()

    @staticmethod
    def re_randomize(ciphertext: PaillierCiphertext) -> None:
        """
        Re-randomises a ciphertext

        :param ciphertext: ciphertext to randomize
        """
        ciphertext.randomize()

    def re_randomise_ht(self) -> None:
        """
        Re-randomises the hidden table
        """
        # cache=True: without it np.vectorize calls re_randomize twice on the
        # first ciphertext (type probe + real call), using up one more
        # randomness value than start_randomness_generation accounts for.
        np.vectorize(self.re_randomize, cache=True)(self.hidden_table)  # type: ignore[attr-defined]
class Bob(DataOwner):
    """
    Bob player in the MPC protocol.

    Bob owns the Paillier key pair. He sends the scheme, the number of groups
    and his encrypted data to Alice, and decrypts the masked table that Alice
    returns.
    """

    def __init__(
        self,
        *args: Any,
        paillier_scheme: Optional[Paillier] = None,
        **kwargs: Any,
    ) -> None:
        """
        Initializes player Bob

        :param paillier_scheme: the Paillier scheme to use for encryption.
            When omitted, a fresh scheme with key length 2048 and precision 0
            is generated for this instance.
        :param args: arguments to pass on to base class
        :param kwargs: keyword arguments to pass on to base class
        """
        super().__init__(*args, **kwargs)
        # Build the default here rather than in the signature: a default
        # expression is evaluated once, at definition time, which would run
        # key generation on import and share one scheme between all
        # instances.
        if paillier_scheme is None:
            paillier_scheme = Paillier.from_security_parameter(
                key_length=2048, precision=0
            )
        self._paillier_scheme = paillier_scheme
        self.encrypted_data: Optional[np.ndarray[PaillierCiphertext]] = None  # type: ignore[type-var]  # pylint: disable=unsubscriptable-object
        self._hidden_table: Optional[np.ndarray[PaillierCiphertext]] = None  # type: ignore[type-var]  # pylint: disable=unsubscriptable-object

    @property
    def groups(self) -> int:
        """
        Number of groups in the loaded dataset

        :return: number of groups
        :raise ValueError: raised when there is no data available
        """
        # Go through the `data` property so missing data gives the same
        # ValueError as everywhere else.
        return self.data.shape[1]

    async def start_protocol(self) -> None:
        """
        Starts and runs the protocol
        """
        loop = asyncio.get_running_loop()
        # Send the scheme and the number of groups while the data is being
        # encrypted in the default executor. The number of groups is sent
        # exactly once; Alice also receives it exactly once.
        _, _, self.encrypted_data = await asyncio.gather(
            self.send_paillier_scheme(),
            self.send_number_of_groups(),
            loop.run_in_executor(None, self.encrypt, self.data),
        )
        self.stop_randomness_generation()
        await self.send_encrypted_data()
        await self.receive_share()
        await self.run_mpyc()

    async def send_paillier_scheme(self) -> None:
        """
        Sends the used Paillier scheme to party Alice.
        """
        await self.send_message(
            self.party_A, self.paillier_scheme, msg_id="paillier_scheme"
        )

    async def send_number_of_groups(self) -> None:
        """
        Sends the number of groups to party Alice.
        """
        await self.send_message(self.party_A, self.groups, msg_id="number_of_groups")

    async def send_encrypted_data(self) -> None:
        """
        Sends the encrypted dataset to party Alice.
        """
        await self.send_message(
            self.party_A, self.encrypted_data, msg_id="encrypted_group_data"
        )

    async def receive_share(self) -> None:
        """
        Receive additive secret share produced by party Alice.
        """
        encrypted_share = await self.receive_message(self.party_A, msg_id="share")
        self._mpyc_data = cast(
            "np.ndarray[FixedPoint]", await self.decrypt_share(encrypted_share)  # type: ignore[type-var]
        )
        # Bob's factors are all zero: one row per row of the share, three
        # columns.
        self._mpyc_factors = np.zeros((len(self._mpyc_data), 3), dtype=np.float64)

    async def decrypt_share(self, data: np.ndarray[PaillierCiphertext]) -> Any:  # type: ignore[type-var]  # pylint: disable=unsubscriptable-object
        """
        Decrypt share

        :param data: the dataset (share) to decrypt
        :return: decrypted data set
        """
        # Decryption runs in the default executor so the event loop stays
        # responsive.
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(None, self.decrypt, data)