Transform WGAN into cGAN

mathys_daviet_9cd41ace98a

Mathys Daviet

Posted on January 22, 2024

Transform WGAN into cGAN

"I am working on developing a Generative Adversarial Network (GAN) with the aim of generating new microstructures based on their characteristics. The objective is to create a microstructure using a given characteristic, provided to the GAN in vector form. This process is implemented using a database containing 40,000 pairs of microstructures and their corresponding characteristics. I have already coded a Wasserstein GAN (WGAN) that successfully generates coherent microstructures from the database, although it currently lacks a connection to the specified characteristics. Additionally, I have coded a conditional GAN (cGAN) that operates on the MNIST dataset. However, I require your assistance in merging these two code structures. Thank you very much for any help you can provide!

##cGAN
import numpy as np
import matplotlib.pyplot as plt
import torch
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from torchvision.utils import make_grid
from tqdm import tqdm

# Common config
batch_size = 64

# Generator config
sample_size = 100  # Random sample size
g_alpha = 0.01  # LeakyReLU alpha
g_lr = 1.0e-4  # Learning rate

# Discriminator config
d_alpha = 0.01  # LeakyReLU alpha
d_lr = 1.0e-4  # Learning rate

# Data Loader for MNIST
transform = transforms.ToTensor()
dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
# shuffle=True: without it every epoch replays the exact same batches and,
# combined with drop_last, the trailing partial batch is never trained on.
# drop_last=True keeps every batch at exactly batch_size samples.
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, drop_last=True)


# Converts conditions into feature vectors
class Condition(nn.Module):
    """Embed integer class labels as dense feature vectors.

    Labels are one-hot encoded and passed through
    Linear -> BatchNorm1d -> LeakyReLU.

    Args:
        alpha: Negative slope of the LeakyReLU.
        num_classes: Number of distinct labels (width of the one-hot code).
        num_features: Size of the feature vector produced per label.
    """

    def __init__(self, alpha: float, num_classes: int = 10,
                 num_features: int = 784):
        super().__init__()

        # Remembered so forward() one-hot encodes with the matching width
        self.num_classes = num_classes

        # From one-hot encoding to features: num_classes => num_features
        self.fc = nn.Sequential(
            nn.Linear(num_classes, num_features),
            nn.BatchNorm1d(num_features),
            nn.LeakyReLU(alpha))

    def forward(self, labels: torch.Tensor) -> torch.Tensor:
        """Return one feature vector per entry of the integer tensor ``labels``."""
        # One-hot encode labels, then cast Long -> Float for the linear layer
        x = F.one_hot(labels, num_classes=self.num_classes).float()

        # To feature vectors
        return self.fc(x)


# Reshape helper
class Reshape(nn.Module):
    """Module wrapper around ``Tensor.reshape`` with a free leading dimension."""

    def __init__(self, *shape):
        super().__init__()

        # Target shape of everything after the leading dimension
        self.shape = shape

    def forward(self, x):
        # -1 lets the leading dimension be inferred from the input size
        target_shape = (-1, *self.shape)
        return x.reshape(target_shape)


# Generator network
class Generator(nn.Module):
    """Conditional generator producing 1 x 28 x 28 images from labels.

    Random noise is sampled inside ``forward``; the label features from
    ``Condition`` are added to the projected noise before two transposed
    convolutions upsample the result.

    Args:
        sample_size: Length of the random noise vector drawn per image.
        alpha: Negative slope shared by all LeakyReLU activations.
    """

    def __init__(self, sample_size: int, alpha: float):
        super().__init__()

        # sample_size => 784
        self.fc = nn.Sequential(
            nn.Linear(sample_size, 784),
            nn.BatchNorm1d(784),
            nn.LeakyReLU(alpha))

        # 784 => 16 x 7 x 7
        self.reshape = Reshape(16, 7, 7)

        # 16 x 7 x 7 => 32 x 14 x 14
        self.conv1 = nn.Sequential(
            nn.ConvTranspose2d(16, 32,
                               kernel_size=5, stride=2, padding=2,
                               output_padding=1, bias=False),
            nn.BatchNorm2d(32),
            nn.LeakyReLU(alpha))

        # 32 x 14 x 14 => 1 x 28 x 28
        self.conv2 = nn.Sequential(
            nn.ConvTranspose2d(32, 1,
                               kernel_size=5, stride=2, padding=2,
                               output_padding=1, bias=False),
            nn.Sigmoid())

        # Random value sample size
        self.sample_size = sample_size

        # To convert labels into feature vectors
        self.cond = Condition(alpha)

    def forward(self, labels: torch.Tensor) -> torch.Tensor:
        """Generate one image per entry of ``labels``."""
        # Labels as feature vectors
        c = self.cond(labels)

        # One noise vector per label
        n_samples = len(labels)

        # Draw the noise on the labels' device so the module keeps working
        # after being moved off the CPU (the default device is always CPU).
        z = torch.randn(n_samples, self.sample_size, device=labels.device)

        # Inputs are the sum of random inputs and label features
        x = self.fc(z)  # => 784
        x = self.reshape(x + c)  # => 16 x 7 x 7
        x = self.conv1(x)  # => 32 x 14 x 14
        x = self.conv2(x)  # => 1 x 28 x 28
        return x


# Discriminator network
class Discriminator(nn.Module):
    """Conditional discriminator that returns its own training loss.

    Image features from two strided convolutions are summed with the label
    features, scored by a small MLP, and compared with ``targets`` using
    binary cross-entropy on the logits.

    Args:
        alpha: Negative slope shared by all LeakyReLU activations.
    """

    def __init__(self, alpha: float):
        super().__init__()

        # Both convolutions halve the spatial size with the same settings
        conv_kwargs = dict(kernel_size=5, stride=2, padding=2, bias=False)

        # 1 x 28 x 28 => 32 x 14 x 14
        self.conv1 = nn.Sequential(
            nn.Conv2d(1, 32, **conv_kwargs),
            nn.LeakyReLU(alpha))

        # 32 x 14 x 14 => 16 x 7 x 7
        self.conv2 = nn.Sequential(
            nn.Conv2d(32, 16, **conv_kwargs),
            nn.BatchNorm2d(16),
            nn.LeakyReLU(alpha))

        # 16 x 7 x 7 => 784 => 1 (a single logit)
        self.fc = nn.Sequential(
            nn.Flatten(),
            nn.Linear(784, 784),
            nn.BatchNorm1d(784),
            nn.LeakyReLU(alpha),
            nn.Linear(784, 1))

        # Label features reshaped to match the image features: 784 => 16 x 7 x 7
        self.cond = nn.Sequential(
            Condition(alpha),
            Reshape(16, 7, 7))

    def forward(self, images: torch.Tensor,
                labels: torch.Tensor,
                targets: torch.Tensor):
        """Return the BCE-with-logits loss of the real/fake prediction against ``targets``."""
        label_features = self.cond(labels)

        # 1 x 28 x 28 => 32 x 14 x 14 => 16 x 7 x 7
        image_features = self.conv2(self.conv1(images))

        # Image features + label features => one logit per sample
        logits = self.fc(image_features + label_features)

        return F.binary_cross_entropy_with_logits(logits, targets)


# To save grid images
def save_image_grid(epoch: int, images: torch.Tensor, ncol: int):
    """Tile ``images`` into one grid and save it as ``generated_<epoch>.jpg``.

    ``ncol`` is forwarded as the second positional argument of ``make_grid``.
    The file is written to the current working directory.
    """
    # Into a grid, channel moved last, then converted to NumPy for matplotlib
    grid = make_grid(images, ncol).permute(1, 2, 0).cpu().numpy()

    plt.imshow(grid)
    # Hide the tick marks on both axes
    for set_ticks in (plt.xticks, plt.yticks):
        set_ticks([])
    plt.savefig(f'generated_{epoch:03d}.jpg')
    plt.close()


# Real / Fake targets (fixed size: the data loader drops incomplete batches)
real_targets = torch.ones(batch_size, 1)
fake_targets = torch.zeros(batch_size, 1)

# Generator and discriminator
generator = Generator(sample_size, g_alpha)
discriminator = Discriminator(d_alpha)

# Optimizers
d_optimizer = torch.optim.Adam(discriminator.parameters(), lr=d_lr)
g_optimizer = torch.optim.Adam(generator.parameters(), lr=g_lr)

# Train loop
for epoch in range(100):

    d_losses = []
    g_losses = []

    for images, labels in tqdm(dataloader):
        # ===============================
        # Discriminator Network Training
        # ===============================

        # Images from MNIST are considered as real
        d_loss = discriminator(images, labels, real_targets)

        # Images from Generator are considered as fake.
        # detach(): only the discriminator is updated in this step, so there
        # is no need to backpropagate through the generator.
        fake_images = generator(labels).detach()
        d_loss += discriminator(fake_images, labels, fake_targets)

        # Discriminator parameter update
        d_optimizer.zero_grad()
        d_loss.backward()
        d_optimizer.step()

        # ===============================
        # Generator Network Training
        # ===============================

        # Images from Generator should be as real as ones from MNIST
        g_loss = discriminator(generator(labels), labels, real_targets)

        # Generator parameter update
        g_optimizer.zero_grad()
        g_loss.backward()
        g_optimizer.step()

        # Keep losses for logging
        d_losses.append(d_loss.item())
        g_losses.append(g_loss.item())

    # Print loss
    print(epoch, np.mean(d_losses), np.mean(g_losses))

    # Save generated images: digits 0-9 repeated 8 times (80 samples)
    labels = torch.LongTensor(list(range(10))).repeat(8).flatten()
    # No gradients are needed for preview images
    with torch.no_grad():
        samples = generator(labels)
    save_image_grid(epoch, samples, ncol=10)

##WGAN
import tensorflow as tf
from tensorflow.keras import layers, models
from tensorflow.keras.optimizers import RMSprop
from tensorflow.keras.layers import Dropout
# from tensorflow.keras.constraints import ClipConstraint

import numpy as np
import pandas as pd
from utils import load_parquet_files, wasserstein_loss, generator_loss, ClipConstraint, generate_batches, Conv2DCircularPadding
import matplotlib.pyplot as plt
import os
from vtk import vtkStructuredPoints, vtkXMLImageDataWriter
import vtk
from vtkmodules.util import numpy_support
# from sklearn.model_selection import train_test_split

class Generator(models.Model):
    """Keras generator mapping a latent vector to a 256 x 256 x 1 image.

    The network itself is a ``Sequential`` stack stored in ``self.model``;
    ``call`` simply forwards to it.
    """

    def __init__(self, latent_dim):
        super(Generator, self).__init__()
        self.latent_dim = latent_dim
        self.model = self.build_model()

    def build_model(self):
        """Assemble the upsampling stack: 8x8 -> 16 -> 32 -> 64 -> 128 -> 256."""
        stack = [
            # latent_dim => 8 x 8 x 512
            layers.Dense(8 * 8 * 512, input_dim=self.latent_dim, activation='relu'),
            layers.Reshape((8, 8, 512)),
            # 8 x 8 => 16 x 16
            layers.Conv2DTranspose(256, kernel_size=4, strides=2, padding='same', activation='relu'),
            layers.BatchNormalization(),
            Dropout(0.25),
            # 16 x 16 => 32 x 32
            layers.Conv2DTranspose(128, kernel_size=4, strides=2, padding='same', activation='relu'),
            layers.BatchNormalization(),
            # 32 x 32 => 64 x 64
            layers.Conv2DTranspose(64, kernel_size=4, strides=2, padding='same', activation='relu'),
            layers.BatchNormalization(),
            Dropout(0.25),
            # 64 x 64 => 128 x 128
            layers.UpSampling2D(),
            layers.Conv2D(32, kernel_size=3, padding='same', activation='relu'),
            layers.BatchNormalization(),
            # 128 x 128 => 256 x 256, single channel squashed into [0, 1]
            layers.UpSampling2D(),
            layers.Conv2D(1, kernel_size=3, padding='same', activation='sigmoid'),
        ]

        model = models.Sequential()
        for layer in stack:
            model.add(layer)
        return model

    def call(self, inputs):
        return self.model(inputs)

class Discriminator(models.Model):
    """Keras critic scoring 256 x 256 x 1 images with a single linear output.

    ``circ_pad`` selects the convolution flavour: truthy uses
    ``Conv2DCircularPadding``; falsy uses ``Conv2D`` with 'same' padding and
    a ``ClipConstraint(0.2)`` on each kernel.
    """

    def __init__(self, circ_pad):
        super(Discriminator, self).__init__()
        self.circ_pad = circ_pad
        self.model = self.build_model()

    def build_model(self):
        """Build the five conv + LeakyReLU stages followed by Flatten and Dense(1)."""
        # (filters, kernel_size, strides) per stage; the two variants differ
        # only in the kernel size of the fourth stage (9 vs 7).
        if self.circ_pad:
            conv_specs = [(32, 3, 1), (64, 5, 2), (128, 7, 2), (256, 9, 2), (512, 5, 2)]
        else:
            conv_specs = [(32, 3, 1), (64, 5, 2), (128, 7, 2), (256, 7, 2), (512, 5, 2)]

        model = models.Sequential()
        for stage, (filters, kernel, stride) in enumerate(conv_specs):
            # Only the first layer declares the input shape
            extra = {'input_shape': (256, 256, 1)} if stage == 0 else {}
            if self.circ_pad:
                conv = Conv2DCircularPadding(filters, kernel_size=kernel, strides=stride, **extra)
            else:
                conv = layers.Conv2D(filters, kernel_size=kernel, strides=stride, padding='same',
                                     kernel_constraint=ClipConstraint(0.2), **extra)
            model.add(conv)
            model.add(layers.LeakyReLU(alpha=0.2))

        model.add(layers.Flatten())
        model.add(layers.Dense(1, activation='linear'))
        return model

    def call(self, inputs):
        return self.model(inputs)

class GAN(models.Model):
    """Bundle a generator and a critic into a trainable WGAN.

    On construction the critic is compiled on its own, then frozen and
    chained behind the generator in ``self.model`` (used to train the
    generator). The training data is loaded from ``data_path`` with
    ``load_parquet_files``.
    """

    def __init__(self, generator, discriminator, data_path):
        super(GAN, self).__init__()
        self.generator = generator
        self.discriminator = discriminator
        self.compile_discriminator()
        self.compile_gan()
        self.data_path = data_path
        self.data = load_parquet_files(self.data_path, test = False)

    def compile_discriminator(self):
        """Compile the critic, then mark it non-trainable for the combined model."""
        # `learning_rate` replaces the deprecated `lr` alias, which newer
        # Keras releases ignore or reject.
        self.discriminator.compile(loss=wasserstein_loss, optimizer=RMSprop(learning_rate=0.000005), metrics=['accuracy'])
        self.discriminator.trainable = False

    def compile_gan(self):
        """Build and compile the latent -> generator -> critic model."""
        z = layers.Input(shape=(self.generator.latent_dim,))
        fake_image = self.generator(z)
        validity = self.discriminator(fake_image)
        self.model = models.Model(z, validity)
        self.model.compile(loss=generator_loss, optimizer=RMSprop(learning_rate=0.0005))

    def generate_latent_points(self, latent_dim, n_samples):
        """Return an (n_samples, latent_dim) array of standard-normal noise."""
        x_input = np.random.randn(latent_dim * n_samples)
        x_input = x_input.reshape(n_samples, latent_dim)
        return x_input

    def generate_latent_points_uniform(self, latent_dim, n_samples):
        """Return an (n_samples, latent_dim) array drawn uniformly from [-1, 1)."""
        x_input = np.random.uniform(-1, 1, size = (n_samples, latent_dim))
        return x_input

    def generate_real_samples(self, n_samples):
        """Draw ``n_samples`` distinct training samples with their ``-1`` targets.

        Returns:
            ``(real_samples, labels)``: the stacked samples with a trailing
            channel axis added, and an ``(n_samples, 1)`` array of ``-1``.
        """
        dfs = self.data
        sampled_indices = np.random.choice(len(dfs), size=n_samples, replace=False)

        # Convert only the selected frames rather than the whole dataset
        real_samples = [dfs[i].to_numpy() for i in sampled_indices]
        real_samples = np.stack(real_samples, axis=0)
        real_samples = np.expand_dims(real_samples, axis=-1)
        labels = -(np.ones((n_samples, 1)))

        return real_samples, labels

    def generate_and_save_samples(self, epoch, latent_dim, n_samples, output_dir):
        """Generate ``n_samples`` images and save up to three of them as ``.npy`` files."""
        z = self.generate_latent_points(latent_dim, n_samples)
        generated_samples = self.generator.predict(z)

        # Never index past the generated batch when n_samples < 3
        for i in range(min(3, len(generated_samples))):
             np.save(os.path.join(output_dir, f'generated_example_epoch{epoch}_sample{i}.npy'), generated_samples[i])

# def train_gan(generator, discriminator, gan, latent_dim, n_epochs, n_batch, output_dir):
#     d_losses, g_losses = [], []
#     current_epoch = 0  # Ajoutez cette ligne
#
#     for epoch in range(n_epochs):
#         current_epoch += 1  # Ajoutez cette ligne
#         for _ in range(n_batch):
#             z = gan.generate_latent_points(latent_dim, n_batch)
#             X_fake = generator.predict(z)
#             # X_fake = tf.cast(X_fake > 0.5, tf.float32)
#             X_real, y_real = gan.generate_real_samples(n_samples = n_batch)
#
#             # Entraînement du discriminateur
#             d_loss_real = discriminator.train_on_batch(X_real, y_real)
#             d_loss_fake = discriminator.train_on_batch(X_fake, -np.ones((n_batch, 1)))
#             d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)
#
#             # Entraînement du générateur
#             z = gan.generate_latent_points(latent_dim, n_batch)
#             y_gan = np.ones((n_batch, 1))
#             g_loss = gan.model.train_on_batch(z, y_gan)
#
#         # Enregistrement des pertes pour la visualisation
#         d_losses.append(d_loss[0])
#         g_losses.append(g_loss)
#
#         # Affichage des résultats et sauvegarde des exemples générés
#         print(f"Epoch {current_epoch}, [D loss: {d_loss[0]}, acc.: {100 * d_loss[1]}], [G loss: {g_loss}]")
#         gan.generate_and_save_samples(current_epoch, latent_dim, n_batch, output_dir)
#
#     # Affichage des courbes d'entraînement
#     plot_training_history(d_losses, g_losses)
def train_wgan(generator, discriminator, gan, latent_dim, n_epochs, n_critic, batch_size, output_dir, circ_pad):
    """Train the WGAN and plot the recorded losses.

    For every batch the critic is updated ``n_critic`` times (real samples
    with target -1, generated samples with target +1), then the generator
    is updated once through ``gan.model``.

    Args:
        generator: Model whose ``predict`` turns latent points into images.
        discriminator: Compiled critic supporting ``train_on_batch``.
        gan: Object providing ``data``, ``generate_latent_points``,
            ``model`` and ``generate_and_save_samples``.
        latent_dim: Size of the latent vectors.
        n_epochs: Number of passes over the data.
        n_critic: Critic updates per generator update.
        batch_size: Number of samples per batch.
        output_dir: Directory receiving the samples saved after each epoch.
        circ_pad: When truthy, all critic layer weights are clipped to
            [-0.2, 0.2] after each critic update.

    Raises:
        ValueError: If the data does not yield a single full batch.
    """
    d_losses, g_losses = [], []

    # Targets are the same for every batch
    y_real = -(np.ones((batch_size, 1)))
    y_fake = np.ones((batch_size, 1))

    for current_epoch in range(1, n_epochs + 1):
        # Fresh shuffled batches for every epoch
        batches = generate_batches(gan.data, batch_size)
        if not batches:
            raise ValueError("not enough samples to form a single batch")

        for batch in batches:
            # Add a trailing channel axis to every sample; the real batch
            # does not change between critic updates, so build it once.
            X_real = np.array([np.expand_dims(sample, axis = -1) for sample in batch])

            # Update the critic (discriminator) multiple times
            for _ in range(n_critic):
                z = gan.generate_latent_points(latent_dim, batch_size)
                X_fake = generator.predict(z)

                d_loss_real = discriminator.train_on_batch(X_real, y_real)
                d_loss_fake = discriminator.train_on_batch(X_fake, y_fake)  # Use +1 as the target for fake samples

                # train_on_batch returns [loss, *metrics] when metrics were
                # compiled in; keep only the loss. The critic objective is
                # the SUM of both terms (presumably mean(D(fake)) -
                # mean(D(real)), assuming the critic uses wasserstein_loss).
                d_loss = float(np.ravel(d_loss_real)[0]) + float(np.ravel(d_loss_fake)[0])

                # Clip the weights of the discriminator
                if circ_pad :
                    for layer in discriminator.model.layers:
                        weights = layer.get_weights()
                        weights = [np.clip(w, -0.2, 0.2) for w in weights]
                        layer.set_weights(weights)

            # Update the generator
            z = gan.generate_latent_points(latent_dim, batch_size)
            y_gan = np.ones((batch_size, 1))
            g_loss = gan.model.train_on_batch(z, y_gan)

            # Record losses for visualization
            d_losses.append(-d_loss)  # Negative of the critic loss
            g_losses.append(g_loss)

        # Display results and save generated samples
        print(f"Epoch {current_epoch}, [D loss: {d_losses[-1]}], [G loss: {g_loss}]")
        gan.generate_and_save_samples(current_epoch, latent_dim, batch_size, output_dir)

    # Display training curves
    plot_training_history(d_losses, g_losses)
def plot_training_history(d_losses, g_losses):
    """Show both loss curves on a single matplotlib figure."""
    plt.figure(figsize=(10, 5))

    # (series, plot keyword arguments), drawn in this order
    curves = (
        (d_losses, {'label': 'Discriminator Loss', 'linestyle': '--'}),
        (g_losses, {'label': 'Generator Loss'}),
    )
    for series, style in curves:
        plt.plot(series, **style)

    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.legend()
    plt.title('WGAN Training History')
    plt.show()


# Padding choice: True selects the circular-padding critic
circ_pad = True

# Latent size, number of epochs and batch size
latent_dim = 100
n_epochs = 2
n_batch = 32
data_path = '/Users/mathys/Documents/Projet 3A/preprocessed_data'  # Path to the training data
output_directory = '/Users/mathys/Documents/Projet 3A/result_WGAN'  # Replace with the path of your choice

# exist_ok=True avoids the check-then-create race of a separate exists() test
os.makedirs(output_directory, exist_ok=True)

# Create the model instances
generator = Generator(latent_dim)
discriminator = Discriminator(circ_pad)
gan = GAN(generator, discriminator, data_path)

generator.summary()
discriminator.summary()

# Train the GAN
train_wgan(generator, discriminator, gan, latent_dim, n_epochs, n_critic = 1, batch_size = n_batch, output_dir = output_directory, circ_pad = circ_pad)

# Generate samples with the generator once training is done
z = gan.generate_latent_points(latent_dim=latent_dim, n_samples=n_batch)
generated_samples = generator(z)

# First weight array of every layer that has weights, flattened for plotting
generator_weights = [layer.get_weights()[0].flatten() for layer in generator.layers if len(layer.get_weights()) > 0]
discriminator_weights = [layer.get_weights()[0].flatten() for layer in discriminator.layers if len(layer.get_weights()) > 0]

# Histogram of the generator weights
plt.figure(figsize=(10,5))
for weights in generator_weights :
    plt.hist(weights, bins = 50, alpha = 0.5)
plt.title('Histogramme des poids du générateur')
plt.show()

# Histogram of the discriminator weights
plt.figure(figsize=(10,5))
for weights in discriminator_weights :
    plt.hist(weights, bins = 50, alpha = 0.5)
plt.title('Histogramme des poids du discriminateur')
plt.show()

##utils

import os
import pandas as pd
import tensorflow as tf
import numpy as np
from tensorflow.keras import backend
from tensorflow.keras.constraints import Constraint
from tensorflow.keras import layers
def load_parquet_files(root_folder, test):
    """Load every ``.parquet`` file found one level below ``root_folder``.

    Args:
        root_folder: Directory whose sub-directories hold the parquet files.
        test: When truthy, stop once more than 1000 entries of
            ``root_folder`` have been visited (quick data sample).

    Returns:
        A list with one DataFrame per parquet file, in listing order.
    """
    frames = []

    # `seen` counts every entry of root_folder, directories or not
    for seen, entry in enumerate(os.listdir(root_folder), start=1):
        subdir = os.path.join(root_folder, entry)

        if os.path.isdir(subdir):
            # Read only the parquet files of this sub-directory
            frames.extend(
                pd.read_parquet(os.path.join(subdir, name))
                for name in os.listdir(subdir)
                if name.endswith(".parquet")
            )

        if test and seen > 1000:
            break

    return frames

def generate_batches(data, batch_size):
    """Shuffle ``data`` and split it into full batches of NumPy arrays.

    Args:
        data: Iterable of DataFrames; each is converted with ``to_numpy``.
        batch_size: Number of samples per batch (must be >= 1).

    Returns:
        A list of batches, each a list of exactly ``batch_size`` arrays.
        Left-over samples that do not fill a batch are dropped, so the
        result is empty when fewer than ``batch_size`` samples are given.

    Raises:
        ValueError: If ``batch_size`` is smaller than 1.
    """
    if batch_size < 1:
        raise ValueError("batch_size must be a positive integer")

    data_np = [df.to_numpy() for df in data]
    if data_np:
        np.random.shuffle(data_np)

    # Only full batches are produced, which also covers empty input
    n_full = len(data_np) // batch_size
    return [data_np[k * batch_size:(k + 1) * batch_size] for k in range(n_full)]
def wasserstein_loss(y_true, y_pred):
    """Return the mean of ``y_true * y_pred``."""
    # The sign of y_true decides whether a score is minimised or maximised
    signed_scores = y_true * y_pred
    return backend.mean(signed_scores)


def generator_loss(y_true, y_pred):
    """Return the negated mean of ``y_pred``; ``y_true`` is not used."""
    mean_score = tf.reduce_mean(y_pred)
    return -mean_score


class ClipConstraint(Constraint):
    """Weight constraint clamping every value to [-clip_value, clip_value]."""

    def __init__(self, clip_value):
        self.clip_value = clip_value

    def __call__(self, weights):
        bound = self.clip_value
        return tf.clip_by_value(weights, clip_value_min=-bound, clip_value_max=bound)

    def get_config(self):
        # Constructor arguments needed to recreate the constraint
        return dict(clip_value=self.clip_value)

class Conv2DCircularPadding(layers.Layer):
    """2-D convolution preceded by circular (wrap-around) padding.

    Axes 1 and 2 of the input are each extended by ``kernel - 1`` entries
    copied from the opposite edge, then a 'valid' ``Conv2D`` is applied.

    Args:
        filters: Number of output filters of the convolution.
        kernel_size: Kernel size forwarded to ``Conv2D``.
        strides: Strides forwarded to ``Conv2D``.
        activation: Activation forwarded to ``Conv2D``.
        **kwargs: Forwarded to ``layers.Layer``.
    """

    def __init__(self, filters, kernel_size, strides=(1, 1), activation=None, **kwargs):
        super(Conv2DCircularPadding, self).__init__(**kwargs)
        self.conv = layers.Conv2D(filters, kernel_size, strides=strides, padding='valid', activation=activation)

    @staticmethod
    def _wrap_axis(tensor, total, axis):
        """Circularly extend ``tensor`` by ``total`` entries along axis 1 or 2."""
        # For odd kernels both sides get the same amount; for even kernels
        # the extra entry goes after the data.
        before = total // 2
        after = total - before

        parts = []
        if before:
            # Guarded: a `-0:` slice would select the whole axis
            parts.append(tensor[:, -before:] if axis == 1 else tensor[:, :, -before:])
        parts.append(tensor)
        if after:
            parts.append(tensor[:, :after] if axis == 1 else tensor[:, :, :after])

        return tf.concat(parts, axis=axis) if len(parts) > 1 else tensor

    def call(self, input_tensor):
        # Padding size derived from the kernel size, separately per axis
        kernel_rows, kernel_cols = self.conv.kernel_size

        # Circular padding: axis 1 first, then axis 2 on the padded result
        # so the corners wrap as well
        padded_input = self._wrap_axis(input_tensor, kernel_rows - 1, axis=1)
        padded_input = self._wrap_axis(padded_input, kernel_cols - 1, axis=2)

        # Apply the convolution
        return self.conv(padded_input)

    def get_config(self):
        # Return the constructor arguments (not the inner layer object) so
        # the layer can be rebuilt from its config.
        config = super(Conv2DCircularPadding, self).get_config()
        config.update({
            "filters": self.conv.filters,
            "kernel_size": self.conv.kernel_size,
            "strides": self.conv.strides,
            "activation": tf.keras.activations.serialize(self.conv.activation),
        })
        return config
Enter fullscreen mode Exit fullscreen mode
💖 💪 🙅 🚩
mathys_daviet_9cd41ace98a
Mathys Daviet

Posted on January 22, 2024

Join Our Newsletter. No Spam, Only the good stuff.

Sign up to receive the latest update from our blog.

Related