```python
# Generate Dataset using sklearn
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt

X, y = make_classification(
    n_features=2, n_redundant=0, n_informative=2, n_clusters_per_class=1, random_state=7
)

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

# Plot the data
plt.scatter(X[:, 0], X[:, 1], marker="o", c=y, s=25, edgecolor="k")
plt.show()
```
Federated learning is a machine learning technique in which decentralized clients train on their own local data and send information back to a server without revealing that data. This allows models to be trained with greater privacy, and it has many natural applications.
A High-Level Look
How does Federated Learning Work?
1. An initial model is established on the server, and its weights are sent out to all clients.
2. Each client trains the model on its own local data and sends the weights or gradients back to the server.
3. The server aggregates the weights from all clients.
4. The server updates its model with the aggregated weights and sends the new weights back to each client.
5. Steps 2-4 are repeated for some number of rounds.
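Here is a minimal numerical sketch of this loop. The "model" is just a weight vector, and each client's local training is a single made-up update step; these are stand-ins for illustration, not how a real system would train.

```python
import numpy as np

# A minimal sketch of the federated loop above. The "model" is a weight
# vector and each client's "training" is one made-up gradient step.
rng = np.random.default_rng(0)
server_weights = np.zeros(2)                   # 1. initial model on the server
client_data = [rng.normal(size=(20, 2)) for _ in range(3)]

for _ in range(5):                             # 5. repeat for some rounds
    client_weights = []
    for data in client_data:
        w = server_weights.copy()              # client starts from server weights
        w -= 0.1 * (w - data.mean(axis=0))     # 2. a stand-in local update
        client_weights.append(w)
    # 3-4. aggregate the client weights and update the server model
    server_weights = np.mean(client_weights, axis=0)

print(server_weights)
```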
How do we aggregate the weights?
The following two algorithms come from the paper Communication-Efficient Learning of Deep Networks from Decentralized Data by H. Brendan McMahan, Eider Moore, Daniel Ramage, Seth Hampson, and Blaise Aguera y Arcas from Google in 2016.
FedSGD
A simple way to update the server's model is to take a gradient descent step using the gradients sent from the clients. This method is called FedSGD and is defined as follows:
\[g_k = \nabla F_k(w_t)\] \[ w_{t+1} \leftarrow w_t - \eta \sum_{k=1}^{K}\frac{n_k}{n}g_k\]
Each client k computes the gradient of its local objective, and the server takes a single gradient descent step using the average of the client gradients, weighted by each client's share of the data n_k/n.
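As a small illustration, here is one FedSGD round in numpy. The weights, gradients, and client sample counts below are made up for illustration.

```python
import numpy as np

# One FedSGD round: w_{t+1} = w_t - eta * sum_k (n_k / n) * g_k
# All values here are made up for illustration.
w = np.array([0.5, -1.0])       # current server weights w_t
eta = 0.1                       # learning rate
grads = [
    np.array([0.2, 0.1]),       # g_1 reported by client 1
    np.array([-0.4, 0.3]),      # g_2 reported by client 2
    np.array([0.1, -0.2]),      # g_3 reported by client 3
]
n_k = np.array([50, 30, 20])    # samples held by each client
n = n_k.sum()

update = sum(n_k[k] / n * grads[k] for k in range(len(grads)))
w = w - eta * update
print(w)
```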
FedAVG
FedAVG is a modification of FedSGD that trains each client for multiple local epochs and then averages the weights together. This method uses less communication than FedSGD and is one of the most commonly used algorithms. In the aforementioned paper, each client k starts from the current server weights and runs several epochs of local SGD to produce new local weights w_{t+1}^k, which the server then averages, again weighted by each client's share of the data:
\[ w_{t+1} \leftarrow \sum_{k=1}^{K}\frac{n_k}{n} w_{t+1}^{k}\]
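Distilled into numpy, the server-side FedAVG step is just a sample-weighted average of the client weights. The values below are made up for illustration; the full end-to-end example follows in the next section.

```python
import numpy as np

# FedAVG aggregation: w_{t+1} = sum_k (n_k / n) * w_{t+1}^k
# The client weights here are made up for illustration.
client_weights = np.array([
    [0.9, -1.1],   # w^1 after E local epochs on client 1
    [1.1, -0.9],   # w^2 after E local epochs on client 2
    [1.0, -1.0],   # w^3 after E local epochs on client 3
])
n_k = np.array([50, 30, 20])  # samples held by each client

# np.average normalizes the weights, so passing raw counts is fine
server_weights = np.average(client_weights, axis=0, weights=n_k)
print(server_weights)
```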
Example with code
We first generate a simple dataset with two classes (the first code block above). We then train a centralized model using sklearn and plot its decision boundary. Next, we train a federated model using FedAVG and plot its decision boundary. Finally, we compare the accuracy of the two models.
```python
import numpy as np
from sklearn.linear_model import SGDClassifier

# Train a centralized sklearn SGDClassifier model
model = SGDClassifier(loss="log_loss")
model.fit(X_train, y_train)

# Plot the decision boundary over a grid spanning each feature's range
x1 = np.linspace(X_test[:, 0].min() - 3, X_test[:, 0].max() + 3, 100)
x2 = np.linspace(X_test[:, 1].min() - 3, X_test[:, 1].max() + 3, 100)
xx1, xx2 = np.meshgrid(x1, x2)
X_grid = np.c_[xx1.ravel(), xx2.ravel()]
probs = model.predict_proba(X_grid)[:, 1].reshape(xx1.shape)

plt.contour(xx1, xx2, probs, [0.5], linewidths=1, colors="black")
plt.scatter(X_test[:, 0], X_test[:, 1], marker="o", c=y_test, s=25, edgecolor="k")
plt.show()

# Print the accuracy
accuracy = model.score(X_test, y_test) * 100.0
print(f"Accuracy: {accuracy:.2f}%")
```
Accuracy: 85.00%
```python
import numpy as np
from sklearn.linear_model import SGDClassifier

n_clients = 3
n_epochs = 3
n_rounds = 1

client_models = [SGDClassifier(loss="log_loss") for _ in range(n_clients)]
server_model = SGDClassifier(loss="log_loss")

# Split data into clients
X_clients = np.array_split(X_train, n_clients)
y_clients = np.array_split(y_train, n_clients)

# Initialize server coefficients to 0
server_model.coef_ = np.zeros((1, 2))
server_model.intercept_ = np.zeros(1)
server_model.classes_ = np.array([0, 1])

for _ in range(n_rounds):
    # Set client models to be the same as the server model
    # (copies, so the clients don't share the same underlying arrays)
    for client_model in client_models:
        client_model.coef_ = server_model.coef_.copy()
        client_model.intercept_ = server_model.intercept_.copy()

    # Train each client model on its own data
    for client_model, X, y in zip(client_models, X_clients, y_clients):
        # Split data into batches
        X_batches = np.array_split(X, n_epochs)
        y_batches = np.array_split(y, n_epochs)

        for _ in range(n_epochs):
            for X_batch, y_batch in zip(X_batches, y_batches):
                client_model.partial_fit(X_batch, y_batch, classes=[0, 1])

    # Aggregate the client models using FedAVG, weighting by sample counts
    n_samples = [len(X) for X in X_clients]
    weights = [n / sum(n_samples) for n in n_samples]

    server_model.coef_ = np.average(
        [client_model.coef_ for client_model in client_models], axis=0, weights=weights
    )
    server_model.intercept_ = np.average(
        [client_model.intercept_ for client_model in client_models], axis=0, weights=weights
    )

# Plot the decision boundary of the federated (server) model
x1 = np.linspace(X_test[:, 0].min() - 3, X_test[:, 0].max() + 3, 100)
x2 = np.linspace(X_test[:, 1].min() - 3, X_test[:, 1].max() + 3, 100)
xx1, xx2 = np.meshgrid(x1, x2)
X_grid = np.c_[xx1.ravel(), xx2.ravel()]
probs = server_model.predict_proba(X_grid)[:, 1].reshape(xx1.shape)

plt.contour(xx1, xx2, probs, [0.5], linewidths=1, colors="black")
plt.scatter(X_test[:, 0], X_test[:, 1], marker="o", c=y_test, s=25, edgecolor="k")
plt.show()

# Print the accuracy
accuracy = server_model.score(X_test, y_test) * 100.0
print(f"Accuracy: {accuracy:.2f}%")
```
Accuracy: 85.00%
Now we can see that the federated model has a similar accuracy to the centralized model. If we look at the weights of the server model, we can see that they are similar to the weights of the centralized model.
print(f"Centralized Model Weights: w={model.coef_[0]}, b={model.intercept_[0]}")
print(f"Federated Model Weights: w={server_model.coef_[0]}, b={server_model.intercept_[0]}")
Centralized Model Weights: w=[-6.13807248 21.28558495], b=5.731260349455407
Federated Model Weights: w=[-5.8260368 24.13905334], b=6.432089614040729
Issues with Federated Learning
Federated Learning is a promising approach to training machine learning models on decentralized data. There are situations where Federated Learning is naturally the best solution given how the data is split up. However, there are still many issues that need to be considered before using it.
While Federated Learning helps increase privacy, it does not guarantee it. There are many attacks that use malicious models or the shared gradients to extract information about the local data. To actually guarantee privacy, Federated Learning must be combined with a technique such as differential privacy or fully homomorphic encryption.
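As a toy illustration of one such mitigation, a server could clip each client's update and add Gaussian noise before aggregating. This sketch is illustrative only; the clip norm and noise scale are arbitrary, and a real differential privacy mechanism would calibrate the noise to a formal privacy budget.

```python
import numpy as np

# Toy sketch of differential-privacy-style aggregation: clip each client
# update to a maximum L2 norm, then add Gaussian noise to the average.
# Illustrative only; not a calibrated or secure DP mechanism.
def private_aggregate(client_updates, clip_norm=1.0, noise_std=0.1, seed=0):
    rng = np.random.default_rng(seed)
    clipped = [
        u * min(1.0, clip_norm / (np.linalg.norm(u) + 1e-12))
        for u in client_updates
    ]
    mean = np.mean(clipped, axis=0)
    return mean + rng.normal(0.0, noise_std, size=mean.shape)

updates = [np.array([0.2, -0.5]), np.array([1.5, 0.3]), np.array([-0.1, 0.8])]
print(private_aggregate(updates))
```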
Another issue is that clients typically have different amounts of data and different usage patterns that might not be representative of the overall distribution. These properties are referred to as unbalanced data and non-IID data in the Federated Learning literature.
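For experimentation, a pathologically non-IID partition is easy to construct by sorting the training data by label before splitting, so each client holds mostly one class. This sketch reuses X_train, y_train, and n_clients from the example above.

```python
import numpy as np

# A pathologically non-IID partition: sort by label before splitting, so
# each client holds mostly a single class.
order = np.argsort(y_train)
X_noniid = np.array_split(X_train[order], n_clients)
y_noniid = np.array_split(y_train[order], n_clients)

for i, y_c in enumerate(y_noniid):
    print(f"Client {i} class counts: {np.bincount(y_c, minlength=2)}")
```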
Finally, Federated Learning has to deal with limited communication and a large number of clients. Some clients have limited bandwidth or are offline for long periods of time, so the server has to be able to handle clients that are not always available.
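One standard mitigation, used in the FedAVG paper itself, is to involve only a random fraction C of clients in each round, so unavailable clients can simply be skipped. A sketch, reusing n_clients from the example above:

```python
import numpy as np

# Sample only a fraction C of clients each round (the C parameter from the
# FedAVG paper), so offline or slow clients are simply skipped.
C = 0.5
rng = np.random.default_rng(0)
n_selected = max(1, int(C * n_clients))
selected = rng.choice(n_clients, size=n_selected, replace=False)
print(f"Clients participating in this round: {sorted(selected)}")
```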
Further Reading
Other algorithms for federated learning include:
1. FedDyn
2. Sub-FedAvg
3. FedAvgM
4. FedAdam
Frameworks for federated learning include:
1. TensorFlow Federated
2. Flower