Privacy-Preserving Machine Learning with AIJack - 2: Model Inversion Attack against Federated Learning on PyTorch
Koukyosyumei
Posted on January 4, 2023
This post is part of our Privacy-Preserving Machine Learning with AIJack series.
- Part 1: Federated Learning
- Part 2: Model Inversion Attack against Federated Learning
- Part 3: Federated Learning with Homomorphic Encryption
- Part 4: Federated Learning with Differential Privacy
- Part 5: Federated Learning with Sparse Gradient
- Part 6: Poisoning Attack against Federated Learning
- Part 7: Federated Learning with FoolsGold
- Part 8: Split Learning
- Part 9: Label Leakage against Split Learning
Overview
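Although Federated Learning allows clients to keep their private datasets local, many papers [1, 2, 3] show that a malicious server can recover private training samples from the uploaded local gradient $\nabla W$.
Since the server already knows the parameters $W$ of the global model, it can estimate the private training sample by solving the following optimization problem:
$$x^{*}, y^{*} = \arg\min_{x', y'} D\left(\nabla W', \nabla W\right), \quad \text{where} \quad \nabla W' = \frac{\partial \ell\left(F(x'; W), y'\right)}{\partial W}$$
Here $F$ is the global model, $\ell$ is the loss function, $D$ is a distance between gradients, and $(x', y')$ is the fake sample being optimized.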
In other words, this attack tries to reconstruct the private training data by optimizing the fake data to generate gradients close enough to the received gradients from the client.
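The idea can be sketched in plain PyTorch without AIJack. The snippet below is a minimal illustration under simplifying assumptions, not AIJack's implementation: the tiny linear model, the 8-dimensional input, the helper `gradient_distance`, and the choice of Adam are all hypothetical, and the label is assumed to be known so that only the input is optimized.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical toy setup: a tiny linear classifier stands in for the global
# model, and a random vector stands in for the private training sample.
model = nn.Linear(8, 3)
criterion = nn.CrossEntropyLoss()
x_private = torch.randn(1, 8)
y_private = torch.tensor([1])  # assumed known to the attacker in this sketch

# The gradient that the client would upload for its private sample.
loss = criterion(model(x_private), y_private)
received_grads = [
    g.detach() for g in torch.autograd.grad(loss, model.parameters())
]


def gradient_distance(x_fake):
    """Squared L2 distance between the gradients of x_fake and the received ones."""
    fake_loss = criterion(model(x_fake), y_private)
    # create_graph=True lets us differentiate the distance with respect to x_fake.
    fake_grads = torch.autograd.grad(
        fake_loss, model.parameters(), create_graph=True
    )
    return sum(((fg - rg) ** 2).sum() for fg, rg in zip(fake_grads, received_grads))


# Start from random noise and optimize only the fake input; the model is fixed.
x_fake = torch.randn(1, 8, requires_grad=True)
optimizer = torch.optim.Adam([x_fake], lr=0.1)
initial_distance = gradient_distance(x_fake).item()

for _ in range(300):
    optimizer.zero_grad()
    distance = gradient_distance(x_fake)
    distance.backward()
    optimizer.step()

final_distance = gradient_distance(x_fake).item()
print(f"gradient distance: {initial_distance:.4f} -> {final_distance:.6f}")
```

As the gradient distance shrinks, `x_fake` is pushed toward an input that explains the uploaded gradient, which is how the private sample leaks.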
Code
Many works propose different distance metrics, regularization terms, and optimization methods for this attack, and AIJack supports many of the popular components.
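To make these components concrete, here is a rough sketch of two gradient distances and one image regularizer that commonly appear in this line of work. The function names `l2_distance`, `cosine_distance`, and `total_variation` are hypothetical helpers written for illustration; they are not AIJack's API.

```python
import torch
import torch.nn.functional as F


def l2_distance(grads_a, grads_b):
    """Sum of squared differences over all parameter gradients."""
    return sum(((a - b) ** 2).sum() for a, b in zip(grads_a, grads_b))


def cosine_distance(grads_a, grads_b):
    """One minus the cosine similarity of the flattened gradients."""
    flat_a = torch.cat([a.flatten() for a in grads_a])
    flat_b = torch.cat([b.flatten() for b in grads_b])
    return 1 - F.cosine_similarity(flat_a, flat_b, dim=0)


def total_variation(x):
    """Anisotropic total variation of an image batch shaped (N, C, H, W)."""
    diff_h = (x[:, :, 1:, :] - x[:, :, :-1, :]).abs().sum()
    diff_w = (x[:, :, :, 1:] - x[:, :, :, :-1]).abs().sum()
    return diff_h + diff_w


# Toy gradients: `scaled` points in the same direction but is twice as long.
grads = [torch.ones(2, 3), torch.ones(3)]
scaled = [2 * g for g in grads]

print(l2_distance(grads, scaled).item())      # sensitive to gradient magnitude
print(cosine_distance(grads, scaled).item())  # close to 0: only direction matters
print(total_variation(torch.zeros(1, 1, 4, 4)).item())  # a flat image has no variation
```

The L2 distance penalizes differences in magnitude, while the cosine distance only compares directions. A total variation term can be added to the objective to favor smooth, image-like reconstructions.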
First, we need to import the necessary libraries.
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from matplotlib import pyplot as plt
from torch.utils.data import DataLoader, TensorDataset
from aijack.attack.inversion import GradientInversionAttackServerManager
from aijack.collaborative.fedavg import FedAVGAPI, FedAVGClient, FedAVGServer
from aijack.utils import NumpyDataset
We use LeNet and MNIST for demonstration purposes.
class LeNet(nn.Module):
    def __init__(self, channel=3, hideen=768, num_classes=10):
        super(LeNet, self).__init__()
        act = nn.Sigmoid
        self.body = nn.Sequential(
            nn.Conv2d(channel, 12, kernel_size=5, padding=5 // 2, stride=2),
            nn.BatchNorm2d(12),
            act(),
            nn.Conv2d(12, 12, kernel_size=5, padding=5 // 2, stride=2),
            nn.BatchNorm2d(12),
            act(),
            nn.Conv2d(12, 12, kernel_size=5, padding=5 // 2, stride=1),
            nn.BatchNorm2d(12),
            act(),
        )
        self.fc = nn.Sequential(nn.Linear(hideen, num_classes))

    def forward(self, x):
        out = self.body(x)
        out = out.view(out.size(0), -1)
        out = self.fc(out)
        return out
def prepare_dataloader(path="MNIST/.", batch_size=64, shuffle=True):
    at_t_dataset_train = torchvision.datasets.MNIST(
        root=path, train=True, download=True
    )
    # Scale pixels to [0, 1] and then normalize them to [-1, 1].
    transform = transforms.Compose(
        [transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))]
    )
    # `data` and `targets` replace the deprecated `train_data` and `train_labels`.
    dataset = NumpyDataset(
        at_t_dataset_train.data.numpy(),
        at_t_dataset_train.targets.numpy(),
        transform=transform,
    )
    dataloader = torch.utils.data.DataLoader(
        dataset, batch_size=batch_size, shuffle=shuffle, num_workers=0
    )
    return dataloader
The hyper-parameters are as follows:
torch.manual_seed(7777)
shape_img = (28, 28)
num_classes = 10
channel = 1
hidden = 588
criterion = nn.CrossEntropyLoss()
num_seeds = 5
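The value `hidden = 588` is the size of the flattened feature map that the LeNet above passes to its linear layer for a 28 x 28 input. The following sketch derives it with the standard convolution output-size formula; the helper `conv_output_size` is a hypothetical name introduced only for this check.

```python
def conv_output_size(size, kernel_size, padding, stride):
    """Output size of a convolution along one spatial dimension (dilation 1)."""
    return (size + 2 * padding - kernel_size) // stride + 1


size = 28  # MNIST images are 28 x 28
for stride in (2, 2, 1):  # strides of the three Conv2d layers in LeNet
    size = conv_output_size(size, kernel_size=5, padding=5 // 2, stride=stride)

out_channels = 12  # every Conv2d layer in LeNet outputs 12 channels
flattened_size = out_channels * size * size
print(size, flattened_size)  # 7 588
```

The spatial size shrinks from 28 to 14 to 7 and then stays at 7, so the linear layer receives 12 * 7 * 7 = 588 features.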
We will try to recover the image shown below.
device = torch.device("cuda:0") if torch.cuda.is_available() else "cpu"
dataloader = prepare_dataloader()
for data in dataloader:
    xs, ys = data[0], data[1]
    break
x = xs[:1]
y = ys[:1]
fig = plt.figure(figsize=(1, 1))
plt.axis("off")
plt.imshow(x.detach().numpy()[0][0], cmap="gray")
plt.show()
As in Part 1, we can easily implement Federated Learning with AIJack. The one big difference is that we wrap the FedAVGServer class with GradientInversionAttackServerManager so that the server can execute a gradient-based model inversion attack. This manager class makes the server estimate the private data from the uploaded gradients in each communication round. We run the attack five times with different random seeds.
manager = GradientInversionAttackServerManager(
    (1, 28, 28),
    num_trial_per_communication=num_seeds,
    log_interval=0,
    num_iteration=100,
    distancename="l2",
    device=device,
    gradinvattack_kwargs={"lr": 1.0},
)
DLGFedAVGServer = manager.attach(FedAVGServer)
client = FedAVGClient(
    LeNet(channel=channel, hideen=hidden, num_classes=num_classes).to(device),
    lr=1.0,
    device=device,
)
server = DLGFedAVGServer(
    [client],
    LeNet(channel=channel, hideen=hidden, num_classes=num_classes).to(device),
    lr=1.0,
    device=device,
)
local_dataloaders = [DataLoader(TensorDataset(x, y))]
local_optimizers = [optim.SGD(client.parameters(), lr=1.0)]
api = FedAVGAPI(
    server,
    [client],
    criterion,
    local_optimizers,
    local_dataloaders,
    num_communication=1,
    local_epoch=1,
    use_gradients=True,
    device=device,
)
api.run()
Then, we can confirm that the attacker can successfully recover the original private image with all random seeds.
fig = plt.figure(figsize=(5, 2))
for s, result in enumerate(server.attack_results[0]):
    ax = fig.add_subplot(1, len(server.attack_results[0]), s + 1)
    ax.imshow(result[0].cpu().detach().numpy()[0][0], cmap="gray")
    ax.axis("off")
plt.tight_layout()
plt.show()
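Visual inspection can be complemented with a numeric score. The sketch below computes the mean squared error and PSNR between an original image and a reconstruction. The helper `mse_and_psnr` is a hypothetical name, the tensors are synthetic stand-ins rather than real attack results, and `data_range=2.0` assumes images normalized to [-1, 1] as in `prepare_dataloader`.

```python
import math

import torch


def mse_and_psnr(original, reconstructed, data_range=2.0):
    """Return (MSE, PSNR in dB) between two tensors of the same shape."""
    mse = torch.mean((original - reconstructed) ** 2).item()
    if mse == 0:
        return mse, float("inf")
    return mse, 10 * math.log10(data_range**2 / mse)


torch.manual_seed(0)
# Synthetic stand-ins: a "private" image, a close reconstruction, and pure noise.
original = torch.rand(1, 1, 28, 28) * 2 - 1
close_reconstruction = original + 0.01 * torch.randn_like(original)
unrelated_noise = torch.rand(1, 1, 28, 28) * 2 - 1

good_mse, good_psnr = mse_and_psnr(original, close_reconstruction)
bad_mse, bad_psnr = mse_and_psnr(original, unrelated_noise)
print(f"close reconstruction: MSE={good_mse:.6f}, PSNR={good_psnr:.1f} dB")
print(f"unrelated noise:      MSE={bad_mse:.6f}, PSNR={bad_psnr:.1f} dB")
```

In the tutorial's setting, the same helper could be applied to `x` and each `result[0].cpu().detach()`, assuming the recovered tensor has the same shape as `x`, as the plotting code suggests.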
Summary
This tutorial showed that Federated Learning alone does not guarantee privacy: a malicious server can reconstruct private training data from the received gradients. You can find more examples of Model Inversion Attacks against Federated Learning in AIJack's documentation. To mitigate this attack, the next tutorial introduces Federated Learning with Homomorphic Encryption, where each client encrypts its local gradients before uploading them.
Reference
[1] Zhu, Ligeng, Zhijian Liu, and Song Han. "Deep Leakage from Gradients." Advances in Neural Information Processing Systems 32 (2019).
[2] Zhao, Bo, Konda Reddy Mopuri, and Hakan Bilen. "iDLG: Improved Deep Leakage from Gradients." arXiv preprint arXiv:2001.02610 (2020).
[3] Yin, Hongxu, et al. "See through Gradients: Image Batch Recovery via GradInversion." Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition. 2021.