Private AI — Federated Learning with PySyft and PyTorch
Over the last few years, we have all witnessed a rapid and significant evolution in the fields of Artificial Intelligence and Machine Learning. This fast development is happening thanks to the improvement in computational power (made available by the latest generations of GPUs and TPUs) and the enormous amount of data that has been accumulated over the years and is being created every second.
From conversational assistants to lung cancer detection, we can clearly see several applications and various benefits of AI development for our society. However, in recent years, this progress has come at a cost: a loss of privacy to some degree. The Cambridge Analytica scandal was the event that raised the alarm about confidentiality and data privacy. Furthermore, the increasing use of data by tech companies, both small and large, has led authorities in several jurisdictions to work on regulations and laws regarding data protection and privacy. The GDPR in Europe is the best-known example of such actions.
These concerns and regulations are not directly compatible with the development of AI and Machine Learning, as models and algorithms have always relied on the availability of data and the possibility of centralizing it on big servers. To address this issue, a new field of research is attracting the interest of ML researchers and practitioners: Private and Secure AI.
What is Private and Secure AI?
This new field consists of a set of techniques that allow ML engineers to train models without having direct access to the data used for training, and that prevent them from obtaining any information about that data, for example through the use of cryptography.
It seems like black magic, doesn't it?
Don't worry… in a series of articles, I will show how it works and how we can apply it to our own Deep Learning models in Python with the open-source library PySyft.
This framework relies on three main techniques:
- Federated Learning
- Differential Privacy
- Secure Multi-Party Computation
In this article, I will cover Federated Learning and its application for SMS spam detection.
Perhaps the easiest concept in Private AI to understand, Federated Learning is a technique for training AI models without having to move the data to a central server. The term was first used by Google in a paper published in 2016.
Schema of a Federated Learning task
The main idea is that, instead of bringing the data to the model, we send the model to where the data is located.
As the data is located on several devices (which I will call workers from here on), the model is sent to each worker, trained there on the local data, and then sent back to the central server.
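To make this round trip concrete, here is a minimal, library-free sketch of one federated round. It assumes a toy 1-D linear model and two hypothetical workers, and it combines the returned weights with a weighted average in the spirit of the FedAvg algorithm; it illustrates the idea only and is not how PySyft or the training loop in this tutorial works.

```python
# Library-free sketch of a federated round (assumptions: toy 1-D linear model
# y = w * x, hypothetical worker data, FedAvg-style weighted averaging).

def local_train(w, data, lr=0.1, steps=20):
    """Gradient descent on one worker's data; the data itself never leaves."""
    for _ in range(steps):
        # Gradient of the mean squared error (w * x - y) ** 2 with respect to w
        grad = sum(2 * (w * x - y) * x for x, y in data) / len(data)
        w -= lr * grad
    return w


def federated_round(w_global, workers):
    """Send the model to every worker, collect the updated weights, average them."""
    updates = [(local_train(w_global, data), len(data)) for data in workers.values()]
    total = sum(n for _, n in updates)
    # Weight each worker's model by how many examples it trained on
    return sum(w * n for w, n in updates) / total


workers = {
    "bob": [(1.0, 2.0), (2.0, 4.0)],  # stays on Bob's device
    "anne": [(3.0, 6.0)],             # stays on Anne's device
}

w = 0.0
for _ in range(5):
    w = federated_round(w, workers)
print(round(w, 3))  # 2.0
```

Only weights travel between the server and the workers; the (x, y) pairs are read exclusively inside `local_train`, which stands for code running on each device.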
One simple example of Federated Learning in the real world comes from Apple devices. QuickType (Apple's text prediction tool) uses models that are sent from time to time to iOS devices via WiFi, trained locally on users' data, and sent back to Apple's central server with their weights updated.
PySyft is an open-source library built for Federated Learning and privacy preservation. It allows its users to perform private and secure Deep Learning. It is built as an extension of several DL libraries, such as PyTorch, Keras and TensorFlow.
If you are interested in more details, you can also take a look at the paper published by OpenMined about the framework.
In this article, I will show a tutorial using the PySyft extension of PyTorch.
Getting started - Setting up the library
In order to install PySyft, it is recommended that you set up a conda environment first:
conda create -n pysyft python=3
conda activate pysyft # or source activate pysyft
conda install jupyter notebook
You then install the package:
pip install syft
Please be sure that you also have PyTorch 1.0.1 or higher installed in your environment.
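If you prefer to check the PyTorch requirement from Python, a small sketch like the one below works under a few assumptions: the installed distribution is named `torch`, and version strings look like `1.13.1` or `2.1.0+cu118` (local suffixes after `+` and pre-release tags such as `rc1` are simply ignored). The helper names are hypothetical.

```python
# Sketch: check the "PyTorch 1.0.1 or higher" requirement (assumptions: the
# distribution is named "torch"; local suffixes after "+" and pre-release tags
# such as "rc1" are ignored). Helper names are hypothetical.
import itertools
from importlib import metadata


def parse_version(version):
    """Turn '2.1.0+cu118' into (2, 1, 0) and '1.0.0rc1' into (1, 0, 0)."""
    release = version.split("+")[0]
    parts = []
    for piece in release.split("."):
        # Keep only the leading digits of each component
        digits = "".join(itertools.takewhile(str.isdigit, piece))
        parts.append(int(digits) if digits else 0)
    return tuple(parts)


def torch_is_recent_enough(minimum="1.0.1"):
    """Return True if an installed torch distribution is at least `minimum`."""
    try:
        installed = metadata.version("torch")
    except metadata.PackageNotFoundError:
        return False
    return parse_version(installed) >= parse_version(minimum)


print(torch_is_recent_enough())
```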
If you get an installation error related to zstd, try uninstalling zstd and reinstalling it:
pip uninstall zstd
pip install --upgrade zstd
If you are still getting errors with the setup, you can alternatively use a Colab notebook and run the following line of code:
!pip install syft
SMS Spam detection with PySyft and PyTorch
The Jupyter notebook with the code below is available on my GitHub page.
In this tutorial, I will simulate two workers, Bob's and Anne's devices, where the SMS messages will be stored. With PySyft, we can simulate these remote machines using the VirtualWorker abstraction.
First, we hook PyTorch:
import torch
import syft as sy

hook = sy.TorchHook(torch)
Then, we create VirtualWorkers:
bob = sy.VirtualWorker(hook, id="bob")
anne = sy.VirtualWorker(hook, id="anne")
We can now send tensors to the workers with the method .send(worker). For example:
x = torch.Tensor([2,2,2]).send(bob)
print(x)
You will probably get something like this as output:
(Wrapper)>[PointerTensor | me:79601866508 -> bob:62597613886]
You can also check where the tensor that the pointer points to is located:
We can see that the tensor is located at a VirtualWorker called "bob" and this worker has one tensor.
Now you can do remote operations using these pointers:
y = torch.Tensor([1,2,3]).send(bob)
sum = x + y
print(sum)
(Wrapper)>[PointerTensor | me:40216858934 -> bob:448194605]
You can see that the operation returns a pointer. To get the tensor back, you need to use the method .get():
sum = sum.get()
print(sum)
tensor([3., 4., 5.])
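To demystify what a pointer is, the sketch below models the idea in plain Python. `ToyWorker`, `ToyPointer` and `send` are hypothetical names made up for this illustration: the data stays in the worker's storage, operations on pointers return new pointers, and `get()` pulls the value back (here also deleting it from the worker). This is not how PySyft is implemented.

```python
# Toy, library-free model of the pointer idea (hypothetical classes; they only
# mimic the behaviour described above and are not PySyft's implementation).
import itertools

_ids = itertools.count(1)


class ToyWorker:
    def __init__(self, worker_id):
        self.id = worker_id
        self._objects = {}  # data that lives "on the device"

    def __repr__(self):
        return f"<ToyWorker id:{self.id} #objects:{len(self._objects)}>"


class ToyPointer:
    def __init__(self, location, id_at_location):
        self.location = location
        self.id_at_location = id_at_location

    def __add__(self, other):
        # The addition runs on the remote worker; only a new pointer comes back
        if other.location is not self.location:
            raise ValueError("both operands must live on the same worker")
        objs = self.location._objects
        a, b = objs[self.id_at_location], objs[other.id_at_location]
        return send([u + v for u, v in zip(a, b)], self.location)

    def get(self):
        # Bring the value back and remove it from the remote worker
        return self.location._objects.pop(self.id_at_location)


def send(value, worker):
    """Store `value` on `worker` and return a pointer to it."""
    obj_id = next(_ids)
    worker._objects[obj_id] = value
    return ToyPointer(worker, obj_id)


bob = ToyWorker("bob")
x = send([2, 2, 2], bob)
print(x.location)   # <ToyWorker id:bob #objects:1>
y = send([1, 2, 3], bob)
total = x + y       # a pointer, not the data
print(total.get())  # [3, 4, 5]
print(bob)          # <ToyWorker id:bob #objects:2>
```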
The most remarkable thing is that we can perform the operations provided by the PyTorch API on these pointers, such as computing losses, zeroing the gradients, performing backpropagation, etc.
Now that you understand the basics of VirtualWorkers and pointers, we can train our model using Federated Learning.
Preparing data and sending it to remote workers
To simulate the remote data, we will use the SMS Spam Collection Data Set available on the UCI Machine Learning Repository. It consists of about 5,500 SMS messages, of which around 13% are spam. We will send about half of the messages to Bob's device and the other half to Anne's device.
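Setting the PySyft-specific parts aside, the bookkeeping of this split can be sketched in plain Python. The function name and the dictionary layout are hypothetical; the 20% test fraction and the halving between Bob and Anne mirror the loading code in this tutorial, and the 5,500 placeholder strings stand in for the real preprocessed messages.

```python
# Sketch of the train/test split and the Bob/Anne halving (hypothetical helper;
# placeholder data instead of the real preprocessed SMS messages).
def split_for_workers(samples, pct_test=0.2):
    # The last pct_test of the samples becomes the test set
    n_train = len(samples) - int(len(samples) * pct_test)
    train, test = samples[:n_train], samples[n_train:]
    half_train, half_test = len(train) // 2, len(test) // 2
    return {
        "bob": {"train": train[:half_train], "test": test[:half_test]},
        "anne": {"train": train[half_train:], "test": test[half_test:]},
    }


messages = [f"sms {i}" for i in range(5500)]  # placeholders
shards = split_for_workers(messages)
for worker, parts in shards.items():
    print(worker, len(parts["train"]), len(parts["test"]))
# bob 2200 550
# anne 2200 550
```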
For this project, I performed some text and data preprocessing that I will not show here, but if you are interested, you can take a look at the script I used, available on my GitHub page. Please also note that in a real-life scenario, this preprocessing would be done on each user's device.
Let's load the processed data:
import numpy as np
import torch

# Loading data
inputs = np.load('./data/inputs.npy')
inputs = torch.tensor(inputs)
labels = np.load('./data/labels.npy')
labels = torch.tensor(labels)

# splitting training and test data
pct_test = 0.2
train_labels = labels[:-int(len(labels)*pct_test)]
train_inputs = inputs[:-int(len(labels)*pct_test)]
test_labels = labels[-int(len(labels)*pct_test):]
test_inputs = inputs[-int(len(labels)*pct_test):]
We then split the datasets in two and send them to the workers with the class
When training in PyTorch, we use DataLoaders to iterate over the batches. With PySyft we can do a similar iteration with FederatedDataLoaders, where the batches come from several devices, in a federated manner.
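The sketch below imitates the key property of a federated data loader in plain Python: every batch is drawn from a single worker's data and is tagged with the worker it lives on, so the training loop knows where to send the model. It is a hypothetical stand-in, not PySyft's `FederatedDataLoader`; iterating worker by worker and shuffling only within each worker are simplifying assumptions.

```python
# Toy stand-in for the idea behind a federated data loader (hypothetical;
# not PySyft's implementation). A batch never mixes two workers' data.
import random


def federated_batches(datasets, batch_size, shuffle=True, seed=0):
    """Yield (worker_id, batch) pairs, one worker's data at a time."""
    rng = random.Random(seed)
    for worker_id, samples in datasets.items():
        order = list(range(len(samples)))
        if shuffle:
            rng.shuffle(order)  # shuffle only inside this worker's data
        for start in range(0, len(order), batch_size):
            yield worker_id, [samples[i] for i in order[start:start + batch_size]]


datasets = {"bob": list(range(0, 70)), "anne": list(range(100, 150))}
for worker_id, batch in federated_batches(datasets, batch_size=32):
    print(worker_id, len(batch))
# bob 32
# bob 32
# bob 6
# anne 32
# anne 18
```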
Training a GRU Model
For this task, I decided to use a classifier based on a 1-layer GRU network. Unfortunately, the current version of PySyft does not yet support PyTorch's RNN modules. However, I was able to handcraft a simple GRU network with linear layers, which are supported by PySyft.
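For reference, here is one GRU step written only with linear maps and element-wise operations, in plain Python. It follows the standard GRU equations documented for `torch.nn.GRU` (reset gate r, update gate z, candidate state n); the author's `handcrafted_GRU` module is not shown here and may be organized differently, and the helper names and parameter layout are assumptions.

```python
# One GRU step from linear maps plus element-wise ops (assumption: standard
# torch.nn.GRU equations; helper names and parameter layout are hypothetical).
import math
import random


def linear(weights, bias, vec):
    """y = W @ v + b, with W stored as a list of rows."""
    return [sum(w * v for w, v in zip(row, vec)) + b for row, b in zip(weights, bias)]


def sigmoid(v):
    return [1.0 / (1.0 + math.exp(-a)) for a in v]


def add(u, v):
    return [a + b for a, b in zip(u, v)]


def gru_step(params, x, h):
    # Reset gate r and update gate z
    r = sigmoid(add(linear(*params["ir"], x), linear(*params["hr"], h)))
    z = sigmoid(add(linear(*params["iz"], x), linear(*params["hz"], h)))
    # Candidate state n: the reset gate scales the hidden-state contribution
    hn = linear(*params["hn"], h)
    n = [math.tanh(a + ri * b) for a, ri, b in zip(linear(*params["in"], x), r, hn)]
    # New hidden state: interpolate between the candidate and the old state
    return [(1 - zi) * ni + zi * hi for zi, ni, hi in zip(z, n, h)]


def init_params(input_dim, hidden_dim, scale=0.1, seed=0):
    rng = random.Random(seed)

    def layer(in_dim):
        W = [[rng.uniform(-scale, scale) for _ in range(in_dim)] for _ in range(hidden_dim)]
        return (W, [0.0] * hidden_dim)

    return {"ir": layer(input_dim), "iz": layer(input_dim), "in": layer(input_dim),
            "hr": layer(hidden_dim), "hz": layer(hidden_dim), "hn": layer(hidden_dim)}


params = init_params(input_dim=3, hidden_dim=2, scale=0.5)
h = [0.0, 0.0]
for x in ([1.0, 0.0, 0.0], [0.0, 1.0, 0.0]):  # a two-step toy sequence
    h = gru_step(params, x, h)
print(h)
```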
Let's initialize the model!
from handcrafted_GRU import GRU

# Training params
EPOCHS = 15
CLIP = 5  # gradient clipping - to avoid gradient explosion
lr = 0.1
BATCH_SIZE = 32

# Model params
EMBEDDING_DIM = 50
HIDDEN_DIM = 10
DROPOUT = 0.2

# Initializing the model
# (VOCAB_SIZE is assumed to be set by the preprocessing step, which is not shown)
model = GRU(vocab_size=VOCAB_SIZE, hidden_dim=HIDDEN_DIM, embedding_dim=EMBEDDING_DIM, dropout=DROPOUT)
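The `CLIP` parameter refers to gradient clipping. As a plain-Python sketch of what clipping by global norm does, assuming the usual L2-norm variant behind `torch.nn.utils.clip_grad_norm_` (the training loop that would use it is not reproduced here):

```python
# Sketch of gradient clipping by global L2 norm (assumption: mirrors the idea
# behind torch.nn.utils.clip_grad_norm_; helper name is hypothetical).
import math


def clip_grad_norm(grads, max_norm, eps=1e-6):
    """Rescale all gradients together if their global L2 norm exceeds max_norm."""
    total_norm = math.sqrt(sum(g * g for g in grads))
    clip_coef = max_norm / (total_norm + eps)
    if clip_coef < 1.0:
        grads = [g * clip_coef for g in grads]
    return grads, total_norm


clipped, norm = clip_grad_norm([6.0, 8.0], max_norm=5.0)
print(norm)     # 10.0
print(clipped)  # roughly [3.0, 4.0]
```

Scaling every gradient by the same factor keeps the update direction unchanged and only shortens the step, which is why clipping helps against exploding gradients in recurrent networks.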
And now train it!
Please note lines 8, 12, 13 and 27 of the training loop. These are the steps that differentiate centralized training in PyTorch from federated training with PySyft.
After getting the model back at the end of the training loop, we can evaluate its performance on local or remote test sets with a similar approach. In this case, I was able to achieve an AUC score of over 97.5%, suggesting that, at least for this task, training in a federated manner does not hurt performance. However, we can notice an increase in overall computation time.
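The article does not say how the AUC was computed. As an illustration, here is a library-free sketch using the pairwise definition of ROC AUC: the probability that a randomly chosen spam message gets a higher score than a randomly chosen legitimate one, with ties counted as one half. It is quadratic in the number of examples, so a rank-based implementation or a library function is preferable on a full test set; the labels and scores below are made up.

```python
# Pairwise ROC AUC (illustrative; O(P * N), fine for small examples).
def roc_auc(labels, scores):
    pos = [s for y, s in zip(labels, scores) if y == 1]
    neg = [s for y, s in zip(labels, scores) if y == 0]
    if not pos or not neg:
        raise ValueError("AUC needs at least one positive and one negative example")
    wins = 0.0
    for p in pos:
        for n in neg:
            # A correctly ordered pair counts 1, a tie counts 0.5
            wins += 1.0 if p > n else 0.5 if p == n else 0.0
    return wins / (len(pos) * len(neg))


labels = [1, 0, 1, 0, 0]            # 1 = spam, 0 = legitimate (made-up data)
scores = [0.9, 0.2, 0.6, 0.7, 0.1]  # model outputs (made-up)
print(roc_auc(labels, scores))      # 0.8333333333333334
```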
We can see that with the PySyft library and its PyTorch extension, we can perform operations on tensor pointers just as we do with the PyTorch API (apart from some limitations that are still to be addressed).
Thanks to this, we were able to train a spam detector model without having any access to the remote and private data: for each batch, we sent the model to the current remote worker and got it back to the local machine before sending it to the worker of the next batch.
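The loop just described can be sketched without any library. Here a 1-feature logistic regression plays the role of the model, the batches are hypothetical, and copying the weights in and out stands in for `model.send(worker)` and `model.get()`; this shows the control flow only, not the tutorial's actual PySyft code.

```python
# Sequential federated training, library-free (assumptions: toy logistic
# regression, hypothetical batches, dict copies standing in for send/get).
import math


def predict(model, x):
    return 1.0 / (1.0 + math.exp(-(model["w"] * x + model["b"])))


def sgd_step(model, batch, lr=0.5):
    """One gradient step of logistic regression, run where the batch lives."""
    gw = gb = 0.0
    for x, y in batch:
        p = predict(model, x)
        gw += (p - y) * x
        gb += p - y
    n = len(batch)
    return {"w": model["w"] - lr * gw / n, "b": model["b"] - lr * gb / n}


# Each worker keeps its own batches; the server never reads them directly
worker_batches = {
    "bob": [[(2.0, 1), (-1.0, 0)], [(1.5, 1), (-2.0, 0)]],
    "anne": [[(3.0, 1), (-0.5, 0)], [(0.5, 1), (-1.5, 0)]],
}

model = {"w": 0.0, "b": 0.0}
for epoch in range(50):
    for worker, batches in worker_batches.items():
        for batch in batches:
            remote_model = dict(model)                    # "send" the model to the worker
            remote_model = sgd_step(remote_model, batch)  # train on that worker's batch
            model = dict(remote_model)                    # "get" the updated model back

print(model["w"] > 0)  # True
print(round(predict(model, 3.0), 3), round(predict(model, -2.0), 3))
```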
There is, however, one limitation of this method: by getting the model back, we can still gain access to some private information. Let's say Bob had only one SMS on his machine. When we get the model back, we can simply check which embeddings of the model changed, and we will know which tokens (words) the SMS contained.
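A toy illustration of this leak, with a hypothetical vocabulary and message: with plain SGD and no weight decay or momentum, only the embedding rows of tokens that appear in the batch receive a nonzero gradient, so comparing the table before and after training reveals the words. The constant gradient of 1.0 below is a stand-in assumption for the real gradient.

```python
# Toy embedding-leak demo (hypothetical vocabulary and message; the constant
# gradient of 1.0 is a stand-in for the real gradient).
def train_on_message(embeddings, token_ids, lr=0.1):
    updated = [row[:] for row in embeddings]
    for t in token_ids:
        # Only rows of tokens present in the message are touched
        updated[t] = [v - lr * 1.0 for v in updated[t]]
    return updated


def leaked_tokens(before, after, vocab):
    """Return the tokens whose embedding rows changed."""
    return [vocab[i] for i, (b, a) in enumerate(zip(before, after)) if b != a]


vocab = ["free", "prize", "meeting", "call", "now", "lunch"]
embeddings = [[0.1 * i, -0.1 * i] for i in range(len(vocab))]
bob_sms = [vocab.index(w) for w in ["free", "prize", "call", "now"]]

returned = train_on_message(embeddings, bob_sms)   # model comes back from Bob
print(leaked_tokens(embeddings, returned, vocab))  # ['free', 'prize', 'call', 'now']
```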
To address this issue, there are two solutions: Differential Privacy and Secure Multi-Party Computation (SMPC). Differential Privacy would be used to make sure the model does not give access to private information. SMPC, which is one kind of Encrypted Computation, in turn allows you to send the model privately so that the remote workers holding the data cannot see the weights you are using.
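To give an intuition for how SMPC can hide values, here is a minimal sketch of additive secret sharing, one of its common building blocks. The modulus, the number of parties and the helper names are assumptions for illustration; this is not PySyft's API.

```python
# Additive secret sharing sketch (assumptions: integer values, a hypothetical
# prime modulus Q, three parties; helper names are hypothetical).
import random

Q = 2**31 - 1  # all arithmetic is done modulo this prime


def share(secret, n_parties=3, rng=random):
    """Split `secret` into random-looking shares that sum to it modulo Q."""
    shares = [rng.randrange(Q) for _ in range(n_parties - 1)]
    shares.append((secret - sum(shares)) % Q)
    return shares


def reconstruct(shares):
    return sum(shares) % Q


w_shares = share(42)  # e.g. a model weight, one share per worker
x_shares = share(10)  # e.g. a data value, one share per worker
# Each party adds its own two shares locally; no single party sees 42 or 10
sum_shares = [(a + b) % Q for a, b in zip(w_shares, x_shares)]
print(reconstruct(sum_shares))  # 52
```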
I will show how to apply these techniques with PySyft in the next article.
Feel free to give me feedback and ask questions!
If you are interested in learning more about Secure and Private AI and how to use PySyft you can also check out this free course on Udacity. It's a great course for beginners taught by Andrew Trask, the founder of the OpenMined Initiative.