Federated Learning for Tabular Data Using the Flower Framework
parmarjatin4911@gmail.com
Posted on January 20, 2024
With advances in AI, one of the major issues facing the AI community is users' concerns about the privacy and security of their data. Users want their data kept private, yet good machine learning models need lots of data. Federated learning emerged as a way to address this tension.
Federated learning is a decentralized machine learning technique that trains a model on multiple local datasets held on multiple devices, without sending the raw data to a central server. Like conventional machine learning, federated learning can work with any type of data, for example text, numeric (tabular) data, images, videos, time series, and voice.
Examples of applications that use federated learning include Google Keyboard, Google Assistant, and Apple’s Siri.
How Federated Learning Works
Federated learning differs from centralized machine learning in where the training happens. It involves two main components: clients and a server.
During training, the data stays on the devices (clients). Each device trains a model on its local data and sends only the resulting model weights to a central server, which aggregates them into a single global model. The server then sends the aggregated model back to the devices for further training, and the loop repeats for a set number of rounds.
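To make the aggregation step concrete, here is a minimal NumPy sketch of federated averaging (FedAvg), the strategy the server in this tutorial uses. Each parameter array is averaged across clients, weighted by how many training examples each client used. The `fedavg` function and the example numbers are hypothetical illustrations, not Flower's own implementation.

```python
from typing import List, Tuple

import numpy as np

# One client's contribution: its parameter arrays plus its number of training examples.
ClientUpdate = Tuple[List[np.ndarray], int]


def fedavg(updates: List[ClientUpdate]) -> List[np.ndarray]:
    """Average each parameter array across clients, weighted by local dataset size."""
    total_examples = sum(num_examples for _, num_examples in updates)
    n_arrays = len(updates[0][0])
    averaged = []
    for i in range(n_arrays):
        # Weighted sum of the i-th array (e.g. coef_ or intercept_) over all clients.
        weighted = sum(params[i] * num_examples for params, num_examples in updates)
        averaged.append(weighted / total_examples)
    return averaged


# Two hypothetical clients with a 2-feature linear model: [coef, intercept].
client_a = ([np.array([[1.0, 2.0]]), np.array([0.0])], 300)
client_b = ([np.array([[3.0, 6.0]]), np.array([1.0])], 100)
coef, intercept = fedavg([client_a, client_b])
# coef -> [[1.5, 3.0]], intercept -> [0.25]
```

Client A counts three times as much as client B because it trained on three times as many examples.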
The figure below shows the steps involved in federated learning.
In this article, you will learn how to write a federated learning application for tabular data using the Flower federated learning framework.
Why Flower Framework?
Easy to customize.
Easy to extend and override to build new state-of-the-art systems.
Easy to read and understand.
It supports a variety of deep learning and machine learning frameworks, such as TensorFlow, scikit-learn, PyTorch, JAX, and MXNet.
Prerequisites
To follow through this tutorial, you need to:
Have Python installed on your machine.
Have the following libraries installed on your machine: scikit-learn, pandas, NumPy, and Matplotlib or Seaborn.
Installing and Setting Up Flower
Flower is an open-source federated learning framework for building AI applications that can be trained on data distributed across multiple devices. It is designed to be scalable, efficient, and easy to use.
Before you install Flower, you need to set up a development virtual environment. In this case, we use Pipenv for our virtual environment.
First, create a new directory using the command below:
$ mkdir flower-federated
$ cd flower-federated
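To create a virtual environment with Pipenv and install Flower into it, run the command below:
$ pipenv install "flwr<1.0"
This command creates a virtual environment for the project and installs Flower into it at the same time. Flower's PyPI package is named flwr; the package called flower is an unrelated monitoring tool for Celery. The version constraint matters because the code in this tutorial uses the pre-1.0 Flower API (for example eval_fn, fl.common.Weights, and a plain dictionary for the server config), which Flower 1.0 renamed or replaced.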
Run the command below to activate the virtual environment:
$ pipenv shell
Next, install the remaining dependencies (scikit-learn, pandas, NumPy, and Matplotlib or Seaborn):
$ pipenv install scikit-learn pandas numpy matplotlib seaborn
Data Preparation
Before you write the application, you need to prepare a dataset for each client. The clients' datasets should differ enough from each other to simulate data held on different devices.
This tutorial uses the bank customer churn dataset from Kaggle. It contains customers from several countries, which makes it a natural fit for simulating federated learning: we split the dataset by the country feature, giving client 1 the French customers (france.csv) and client 2 the German customers (germany.csv).
A notebook for the data preparation can be found on GitHub.
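Since the notebook is not reproduced here, below is a hypothetical pandas sketch of that preparation. It assumes the Kaggle file has the columns RowNumber, CustomerId, Surname, Geography, Gender, and Exited (the churn label); `split_by_country` is an illustrative name. It drops the identifier columns, one-hot encodes text columns, and puts the label last, which is the layout the loaders in utils.py expect. The resulting feature count may differ from the 13 used later in `set_initial_params`, so adjust `n_features` to match your files.

```python
from typing import Dict, Tuple

import pandas as pd


def split_by_country(
    df: pd.DataFrame,
    country_col: str = "Geography",
    target_col: str = "Exited",
    id_cols: Tuple[str, ...] = ("RowNumber", "CustomerId", "Surname"),
) -> Dict[str, pd.DataFrame]:
    """Return one all-numeric DataFrame per country, with the label as the last column."""
    # Identifier columns carry no predictive signal; drop whichever of them are present.
    df = df.drop(columns=[c for c in id_cols if c in df.columns])
    countries = df[country_col]
    # One-hot encode the remaining text columns (e.g. Gender) as 0/1 integers.
    # Encoding before splitting keeps the columns identical for every country.
    features = pd.get_dummies(
        df.drop(columns=[country_col, target_col]), drop_first=True, dtype=int
    )
    # Append the label as the last column, the layout utils.py expects.
    data = features.assign(**{target_col: df[target_col]})
    return {
        country: data[countries == country].reset_index(drop=True)
        for country in countries.unique()
    }
```

For example, `parts = split_by_country(pd.read_csv("Churn_Modelling.csv"))` followed by `parts["France"].to_csv("france.csv", index=False)` and the same for `"Germany"` would write the two client files; the input file name is an assumption.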
Writing the Client Application
First, in your project directory, create a utils.py file. It holds helper functions shared by the client and server scripts.
Import Dependencies
Import the necessary libraries.
from typing import Tuple, Union, List
import numpy as np
import pandas as pd
from sklearn.preprocessing import StandardScaler
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split
Type Aliases
Define type aliases describing the (features, labels) pairs, the train/test datasets, and the model parameters that the helper functions pass around.
XY = Tuple[np.ndarray, np.ndarray]
Dataset = Tuple[XY, XY]
LogRegParams = Union[XY, Tuple[np.ndarray]]
XYList = List[XY]
Set model parameters
After model training in the clients, the model parameters are sent to the server for aggregation. Therefore, we will create functions for getting model parameters, setting model parameters and default parameters for the client devices.
def get_model_parameters(model: LogisticRegression) -> LogRegParams:
    """Returns the parameters of a scikit-learn LogisticRegression model."""
    if model.fit_intercept:
        params = [model.coef_, model.intercept_]
    else:
        params = [model.coef_]
    return params


def set_model_params(model: LogisticRegression, params: LogRegParams) -> LogisticRegression:
    """Sets the parameters of a scikit-learn LogisticRegression model."""
    model.coef_ = params[0]
    if model.fit_intercept:
        model.intercept_ = params[1]
    return model


def set_initial_params(model: LogisticRegression):
    """Sets initial parameters to zeros.

    This is required because the model parameters are uninitialized until
    model.fit is called, but the server asks a client for initial parameters
    at launch. Refer to the sklearn.linear_model.LogisticRegression
    documentation for more information.
    """
    n_classes = 2  # the churn dataset has 2 classes
    n_features = 13  # number of features in the dataset; must match your CSV files
    model.classes_ = np.arange(n_classes)
    # For a binary target, scikit-learn stores one row of weights: shape (1, n_features)
    model.coef_ = np.zeros((1, n_features))
    if model.fit_intercept:
        model.intercept_ = np.zeros((1,))
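These helpers work because everything a LogisticRegression model needs in order to predict lives in its classes_, coef_, and intercept_ attributes. The sketch below, on synthetic data rather than the churn files, copies the parameters of a fitted model into a fresh one, much as a client does when it receives the global model, and checks that both predict the same labels. It also shows that for a binary target scikit-learn stores coef_ with shape (1, n_features).

```python
import numpy as np
from sklearn.linear_model import LogisticRegression

# Synthetic binary data with 4 features.
rng = np.random.default_rng(0)
X = rng.normal(size=(200, 4))
y = (X[:, 0] + 0.5 * X[:, 1] > 0).astype(int)

trained = LogisticRegression().fit(X, y)
# For a binary problem scikit-learn keeps a single weight row and one intercept.
print(trained.coef_.shape, trained.intercept_.shape)  # (1, 4) (1,)

# "Send" the parameters as a list of arrays, like get_model_parameters does.
params = [trained.coef_.copy(), trained.intercept_.copy()]

# A never-fitted model on the receiving side, filled in like set_model_params does.
receiver = LogisticRegression()
receiver.classes_ = np.array([0, 1])  # normally set by fit()
receiver.coef_, receiver.intercept_ = params

same = np.array_equal(receiver.predict(X), trained.predict(X))
print(same)  # True
```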
Load Data
Create two functions, load_data_client1() and load_data_client2(), that load the per-country datasets created earlier.
def load_data_client1() -> Dataset:
    """Load the France data and return standardized train/test splits."""
    data = pd.read_csv('france.csv')
    # All columns are numeric features except the last one, the churn label
    df = data.to_numpy()
    X = df[:, :-1]
    y = df[:, -1]
    # Select 80% of the data as training data and 20% as test data
    x_train, x_test, y_train, y_test = train_test_split(
        X, y, test_size=0.2, random_state=42, shuffle=True, stratify=y
    )
    # Standardize the features, fitting the scaler on the training split only
    scaler = StandardScaler().fit(x_train)
    return (scaler.transform(x_train), y_train), (scaler.transform(x_test), y_test)


# Read data for the other client
def load_data_client2() -> Dataset:
    """Load the Germany data and return standardized train/test splits."""
    data = pd.read_csv('germany.csv')
    # All columns are numeric features except the last one, the churn label
    df = data.to_numpy()
    X = df[:, :-1]
    y = df[:, -1]
    # Select 80% of the data as training data and 20% as test data
    x_train, x_test, y_train, y_test = train_test_split(
        X, y, test_size=0.2, random_state=42, shuffle=True, stratify=y
    )
    # Standardize the features, fitting the scaler on the training split only
    scaler = StandardScaler().fit(x_train)
    return (scaler.transform(x_train), y_train), (scaler.transform(x_test), y_test)
def shuffle(X: np.ndarray, y: np.ndarray) -> XY:
    """Shuffle X and y in unison."""
    rng = np.random.default_rng()
    perm = rng.permutation(len(X))
    return X[perm], y[perm]


def partition(X: np.ndarray, y: np.ndarray, num_partitions: int) -> XYList:
    """Split X and y into num_partitions partitions of (nearly) equal size."""
    return list(
        zip(np.array_split(X, num_partitions), np.array_split(y, num_partitions))
    )
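partition relies on np.array_split, which, unlike np.split (which raises an error in this case), accepts a number of partitions that does not divide the number of rows evenly; the first partitions simply get one extra row. A quick check with toy arrays, using the same zip pattern as partition:

```python
import numpy as np

# 10 toy rows with 2 features each, and matching labels 0..9.
X = np.arange(20).reshape(10, 2)
y = np.arange(10)

# Same logic as utils.partition: split X and y in lockstep into 3 parts.
parts = list(zip(np.array_split(X, 3), np.array_split(y, 3)))

print([len(px) for px, _ in parts])  # [4, 3, 3]
# Rows and labels stay aligned inside each partition.
print(parts[1][1])  # [4 5 6]
```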
Next, create two client files, client1.py and client2.py, one for each simulated device, and add the following code blocks to each of them.
Import dependencies
Import the necessary libraries.
import warnings
import flwr as fl
import numpy as np
import sys
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import log_loss
import utils
Client File
The client script trains the model on the device's local data and evaluates the global model on the local test split. In client1.py, the local data is the French data returned by utils.load_data_client1().
if __name__ == "__main__":
    # Load Bank Customer churn data from France
    (X_train, y_train), (X_test, y_test) = utils.load_data_client1()

    # Split the train set into 5 partitions and randomly use one for training
    partition_id = np.random.choice(5)
    (X_train, y_train) = utils.partition(X_train, y_train, 5)[partition_id]

    # Create a Logistic Regression model
    model = LogisticRegression(
        solver="saga",
        penalty="l2",
        max_iter=1,  # local epoch
        warm_start=True,  # prevent refreshing weights when fitting
    )

    # Setting initial parameters, akin to model.compile for Keras models
    utils.set_initial_params(model)

    # Define the Flower client
    class ChurnClient(fl.client.NumPyClient):
        def get_parameters(self):  # type: ignore
            return utils.get_model_parameters(model)

        def fit(self, parameters, config):  # type: ignore
            utils.set_model_params(model, parameters)
            # Ignore convergence failure due to low local epochs
            with warnings.catch_warnings():
                warnings.simplefilter("ignore")
                model.fit(X_train, y_train)
            print(f"Training finished for round {config['rnd']}")
            return utils.get_model_parameters(model), len(X_train), {}

        def evaluate(self, parameters, config):  # type: ignore
            utils.set_model_params(model, parameters)
            preds = model.predict_proba(X_test)
            loss = log_loss(y_test, preds, labels=[0, 1])
            accuracy = model.score(X_test, y_test)
            return loss, len(X_test), {"accuracy": accuracy}

    # Start the Flower client
    fl.client.start_numpy_client(
        server_address="localhost:5040",
        client=ChurnClient(),
    )
In client2.py, copy and paste the code above and change only the data-loading line so that it loads the German data:
    # Load Bank Customer churn data from Germany
    (X_train, y_train), (X_test, y_test) = utils.load_data_client2()
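Before running anything over the network, it can help to see one complete federated loop in a single process. The sketch below is an illustration, not part of the tutorial's code: it uses synthetic data instead of the churn files, and make_client_data, local_fit, fedavg, and evaluate are hypothetical helpers. Each simulated client starts from the global weights and runs one saga epoch (max_iter=1 with warm_start=True, as in the client code; L2 is scikit-learn's default penalty), after which the weights are averaged by example count and scored on held-out data.

```python
import warnings
from typing import List, Tuple

import numpy as np
from sklearn.linear_model import LogisticRegression

Params = List[np.ndarray]

rng = np.random.default_rng(42)
N_FEATURES = 5
TRUE_W = rng.normal(size=N_FEATURES)  # hidden "true" weights behind the synthetic labels


def make_client_data(n: int, shift: float) -> Tuple[np.ndarray, np.ndarray]:
    """Synthetic binary data; `shift` moves each client's feature distribution."""
    X = rng.normal(loc=shift, size=(n, N_FEATURES))
    y = (X @ TRUE_W > 0).astype(int)
    return X, y


def new_model() -> LogisticRegression:
    # Same settings as the client code: one saga pass per fit(), keep previous weights.
    return LogisticRegression(solver="saga", max_iter=1, warm_start=True)


def local_fit(global_params: Params, X: np.ndarray, y: np.ndarray) -> Tuple[Params, int]:
    """One client round: start from the global weights, train one epoch, return new weights."""
    model = new_model()
    model.coef_, model.intercept_ = (p.copy() for p in global_params)
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")  # convergence warnings are expected with max_iter=1
        model.fit(X, y)
    return [model.coef_, model.intercept_], len(X)


def fedavg(updates: List[Tuple[Params, int]]) -> Params:
    """Average each parameter array, weighted by the clients' example counts."""
    total = sum(n for _, n in updates)
    return [sum(p[i] * n for p, n in updates) / total for i in range(len(updates[0][0]))]


def evaluate(params: Params, X: np.ndarray, y: np.ndarray) -> float:
    """Server-side accuracy of the global model on held-out data."""
    model = new_model()
    model.classes_ = np.array([0, 1])
    model.coef_, model.intercept_ = params
    return model.score(X, y)


clients = [make_client_data(400, shift=0.0), make_client_data(200, shift=0.5)]
X_test, y_test = make_client_data(300, shift=0.25)

# Start from all-zero global weights, shaped like a fitted binary model.
global_params = [np.zeros((1, N_FEATURES)), np.zeros(1)]
history = []
for rnd in range(1, 6):
    updates = [local_fit(global_params, X, y) for X, y in clients]
    global_params = fedavg(updates)
    history.append(evaluate(global_params, X_test, y_test))
    print(f"round {rnd}: accuracy {history[-1]:.3f}")
```

In the real application, Flower replaces the list comprehension over clients with network calls to separate processes, and the FedAvg strategy replaces the fedavg helper.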
Writing the Server Application
The server application sends initial model parameters to the clients, collects their updated parameters after each training round, and aggregates them (here with the FedAvg strategy) into a global model, which it evaluates and sends back to the clients.
Create a server.py file and insert the following block of code to build a server application.
Import the necessary libraries
import flwr as fl
import utils
import sys
from sklearn.metrics import log_loss
from sklearn.linear_model import LogisticRegression
from typing import Dict
import pandas as pd
import numpy as np
Training the model
As on the clients, the model is logistic regression, a linear classifier suited to a binary target such as churn. The functions below send the current round number to the clients and evaluate the aggregated model on the server after each round.
def fit_round(rnd: int) -> Dict:
    """Send the current round number to the clients."""
    return {"rnd": rnd}


def get_eval_fn(model: LogisticRegression):
    """Return an evaluation function for server-side evaluation."""
    # Load test data here to avoid the overhead of doing it in `evaluate` itself.
    # Server-side test data: the France test split (the split client 1 evaluates on).
    _, (X_test, y_test) = utils.load_data_client1()

    # The `evaluate` function will be called after every round
    def evaluate(parameters: fl.common.Weights):
        # Update model with the latest parameters
        utils.set_model_params(model, parameters)
        preds = model.predict_proba(X_test)
        loss = log_loss(y_test, preds, labels=[0, 1])
        accuracy = model.score(X_test, y_test)
        # Save the predicted probabilities next to the true labels for inspection
        res = pd.DataFrame(preds, columns=["prob_no_churn", "prob_churn"])
        res["real"] = y_test
        res.to_csv("prediction_results.csv")
        # Flower expects a (loss, metrics) tuple
        return loss, {"accuracy": accuracy}

    return evaluate
Start Flower server for ten rounds of federated learning
if __name__ == "__main__":
    model = LogisticRegression(
        solver="saga",
        penalty="l2",
        max_iter=1,  # local epoch
        warm_start=True,  # prevent refreshing weights when fitting
    )
    utils.set_initial_params(model)
    strategy = fl.server.strategy.FedAvg(
        min_available_clients=2,
        eval_fn=get_eval_fn(model),
        on_fit_config_fn=fit_round,
    )
    fl.server.start_server(
        server_address="localhost:5040",
        strategy=strategy,
        config={"num_rounds": 10},
    )
Once you have finished writing the server application, you can run the whole setup. Open PowerShell or a terminal, navigate to your project directory, and start the server:
$ python server.py
Expected output (abridged; the exact log format depends on the Flower version):
INFO flower 2022-11-28 21:11:50,606 | app.py:134 | Flower server running (10 rounds), SSL is disabled
INFO flower 2022-11-28 21:11:50,606 | server.py:84 | Initializing global parameters
INFO flower 2022-11-28 21:11:50,606 | server.py:256 | Requesting initial parameters from one random client
INFO flower 2022-11-28 21:12:45,134 | server.py:259 | Received initial parameters from one random client
INFO flower 2022-11-28 21:12:45,134 | server.py:86 | Evaluating initial parameters
INFO flower 2022-11-28 21:12:45,197 | server.py:89 | initial parameters (loss, other metrics): {'Aggregated Results: loss ': 0.6931471805599453}, {'accuracy': 0.8384845463609173}
INFO flower 2022-11-28 21:12:45,197 | server.py:99 | FL starting
DEBUG flower 2022-11-28 21:18:30,120 | server.py:203 | fit_round: strategy sampled 2 clients (out of 2)
INFO flower 2022-11-28 21:18:30,292 | server.py:114 | fit progress: (1, {'Aggregated Results: loss ': 0.42021533430137353}, {'accuracy': 0.8414755732801595}, 345.1011337999953)
DEBUG flower 2022-11-28 21:18:30,292 | server.py:157 | evaluate_round: strategy sampled 2 clients (out of 2)
…
INFO flower 2022-11-28 21:18:30,609 | app.py:181 | app_fit: metrics_centralized {'accuracy': [(0, 0.8384845463609173), (1, 0.8414755732801595), (2, 0.8215353938185443), (3, 0.8295114656031904), (4, 0.8155533399800599), (5, 0.8384845463609173), (6, 0.8305084745762712), (7, 0.8305084745762712), (8, 0.8245264207377866), (9, 0.8325024925224327), (10, 0.8265204386839482)]}
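The initial evaluation line can be checked by hand. With all-zero weights, the model assigns every customer a churn probability of exactly 0.5, so the log loss is ln 2 ≈ 0.6931, and every prediction ties and resolves to class 0 (not churned); the initial accuracy is therefore simply the share of non-churners in the server's test split. The sketch below reproduces this with hypothetical labels (84 non-churners, 16 churners) and random features, not the real test data.

```python
import math

import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import log_loss

# Hypothetical test labels: 84 non-churners (class 0) and 16 churners (class 1).
y = np.array([0] * 84 + [1] * 16)
X = np.random.default_rng(0).normal(size=(len(y), 13))  # 13 random features

# All-zero parameters, as set by utils.set_initial_params before any training.
model = LogisticRegression()
model.classes_ = np.array([0, 1])
model.coef_ = np.zeros((1, 13))
model.intercept_ = np.zeros(1)

probs = model.predict_proba(X)
loss = log_loss(y, probs, labels=[0, 1])
accuracy = model.score(X, y)

print(probs[0])  # [0.5 0.5]
print(loss, math.log(2))  # both ≈ 0.6931
print(accuracy)  # 0.84, the share of class 0
```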
Then open two more terminals and start one client in each:
$ python client1.py
$ python client2.py
Expected output (for each client, abridged):
INFO flower 2022-11-28 21:12:45,088 | connection.py:102 | Opened insecure gRPC connection (no certificates were passed)
DEBUG flower 2022-11-28 21:12:45,103 | connection.py:39 | ChannelConnectivity.IDLE
DEBUG flower 2022-11-28 21:12:45,103 | connection.py:39 | ChannelConnectivity.CONNECTING
DEBUG flower 2022-11-28 21:12:45,103 | connection.py:39 | ChannelConnectivity.READY
Training finished for round 1
…
Training finished for round 10
DEBUG flower 2022-11-28 21:18:30,656 | connection.py:121 | gRPC channel closed
INFO flower 2022-11-28 21:18:30,656 | app.py:101 | Disconnect and shut down
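You have just built and run a federated machine learning application. You can now experiment with the number of training rounds and see how the models perform.
Keep in mind that more rounds do not automatically give a better model. In the sample run above, the server-side accuracy fluctuated between about 0.82 and 0.84 over the ten rounds and ended below its initial value of 0.838, so compare the per-round metrics rather than assuming that training longer helps.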
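To see how the model evolves across rounds, you can plot the per-round accuracy that the server prints at the end of the run. The sketch below parses the metrics_centralized line from the sample server output above (values copied as shown) and saves a plot with Matplotlib; parse_metrics and the output file name are hypothetical.

```python
import ast

import matplotlib

matplotlib.use("Agg")  # render straight to a file; no display needed
import matplotlib.pyplot as plt

# The final metrics_centralized line from the sample server output.
LOG_LINE = (
    "INFO flower 2022-11-28 21:18:30,609 | app.py:181 | app_fit: metrics_centralized "
    "{'accuracy': [(0, 0.8384845463609173), (1, 0.8414755732801595), "
    "(2, 0.8215353938185443), (3, 0.8295114656031904), (4, 0.8155533399800599), "
    "(5, 0.8384845463609173), (6, 0.8305084745762712), (7, 0.8305084745762712), "
    "(8, 0.8245264207377866), (9, 0.8325024925224327), (10, 0.8265204386839482)]}"
)


def parse_metrics(line: str) -> dict:
    """Return the dict printed after 'metrics_centralized' in a Flower log line."""
    payload = line.split("metrics_centralized", 1)[1].strip()
    return ast.literal_eval(payload)  # safely evaluates the Python literal


metrics = parse_metrics(LOG_LINE)
rounds, accuracy = zip(*metrics["accuracy"])

plt.plot(rounds, accuracy, marker="o")
plt.xlabel("Round (0 = initial parameters)")
plt.ylabel("Server-side accuracy")
plt.title("Centralized accuracy per federated round")
plt.savefig("accuracy_per_round.png")
```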