Add class information to keras network

Add class information to keras network

Suggested Solution

Reusing the code from the repository you shared, here are some suggested modifications to train a classifier alongside your generator and discriminator (their architectures and other losses are left untouched):

import numpy as np
from keras import backend as K
from keras.layers import Input, merge
from keras.layers.convolutional import Convolution2D, MaxPooling2D
from keras.layers.core import Dense, Dropout, Activation, Flatten
from keras.models import Model, Sequential


def lenet_classifier_model(nb_classes, input_shape=None):
    """Build a small LeNet-style classifier (snippet by Fabien Tanc).

    Replace with your favorite classifier.

    Args:
        nb_classes: number of output classes (size of the final softmax layer).
        input_shape: shape of a single input sample. Because the classifier is
            applied to the generator's output, this must match that output's
            shape. When omitted, the module-level ``in_shape`` is used, as the
            original snippet did.

    Returns:
        The (uncompiled) ``Sequential`` classifier.
    """
    if input_shape is None:
        input_shape = in_shape  # module-level global defined elsewhere
    model = Sequential()
    model.add(Convolution2D(12, 5, 5, activation='relu', input_shape=input_shape, init='he_normal'))
    model.add(MaxPooling2D(pool_size=(2, 2)))
    model.add(Convolution2D(25, 5, 5, activation='relu', init='he_normal'))
    model.add(MaxPooling2D(pool_size=(2, 2)))
    model.add(Flatten())
    model.add(Dense(180, activation='relu', init='he_normal'))
    model.add(Dropout(0.5))
    model.add(Dense(100, activation='relu', init='he_normal'))
    model.add(Dropout(0.5))
    model.add(Dense(nb_classes, activation='softmax', init='he_normal'))
    # The original snippet forgot this return, so callers received None.
    return model


def generator_containing_discriminator_and_classifier(generator, discriminator, classifier):
    """Stack the discriminator and the classifier on top of the generator.

    The discriminator receives the (input, generated) pair concatenated along
    axis 1 (presumably the channel axis -- confirm against your data layout).
    The classifier receives only the generated image. Both sub-models are
    flagged non-trainable here, before ``train`` compiles the combined model,
    so that training this model updates the generator only.

    Returns:
        A ``Model`` mapping the generator input to
        ``[generator output, discriminator output, classifier output]``.
    """
    inputs = Input((IN_CH, img_cols, img_rows))
    x_generator = generator(inputs)
    merged = merge([inputs, x_generator], mode='concat', concat_axis=1)
    discriminator.trainable = False
    x_discriminator = discriminator(merged)
    classifier.trainable = False
    x_classifier = classifier(x_generator)
    model = Model(input=inputs, output=[x_generator, x_discriminator, x_classifier])
    return model


def train(BATCH_SIZE):
    """Jointly train the generator, the discriminator and the classifier.

    For each batch: train D on real/fake pairs, train the classifier on real
    target images, then train G through the frozen D and classifier.

    Relies on names defined elsewhere in the project: ``get_data``,
    ``discriminator_model``, ``generator_model``, ``generator_l1_loss``,
    ``discriminator_on_generator_loss``, ``discriminator_loss``,
    ``combine_images``, ``cv2``, ``IN_CH``, ``img_cols``, ``img_rows``.

    Args:
        BATCH_SIZE: number of samples per batch. Trailing samples that do not
            fill a whole batch are skipped.
    """
    (X_train, Y_train, LABEL_train) = get_data('train')  # replace with your data here
    # Scale pixel values from [0, 255] to [-1, 1].
    X_train = (X_train.astype(np.float32) - 127.5) / 127.5
    Y_train = (Y_train.astype(np.float32) - 127.5) / 127.5

    discriminator = discriminator_model()
    generator = generator_model()
    classifier = lenet_classifier_model(6)
    generator.summary()
    discriminator_and_classifier_on_generator = generator_containing_discriminator_and_classifier(
        generator, discriminator, classifier)

    generator.compile(loss='mse', optimizer="rmsprop")
    # Compiled while D and the classifier are flagged non-trainable (see the
    # helper above), so this model only updates G.
    discriminator_and_classifier_on_generator.compile(
        loss=[generator_l1_loss, discriminator_on_generator_loss, "categorical_crossentropy"],
        optimizer="rmsprop")
    discriminator.trainable = True
    discriminator.compile(loss=discriminator_loss, optimizer="rmsprop")
    classifier.trainable = True
    classifier.compile(loss="categorical_crossentropy", optimizer="rmsprop")

    # Discriminator targets sized from BATCH_SIZE and D's output shape instead
    # of the previously hard-coded 20 / 10 and (1, 64, 64): real pairs are
    # labelled 1 and fake pairs 0, and G is trained towards "real" (1).
    d_out_shape = tuple(discriminator.output_shape[1:])
    d_targets = np.concatenate((np.ones((BATCH_SIZE,) + d_out_shape),
                                np.zeros((BATCH_SIZE,) + d_out_shape)))
    g_d_targets = np.ones((BATCH_SIZE,) + d_out_shape)

    n_batches = X_train.shape[0] // BATCH_SIZE
    for epoch in range(100):
        print("Epoch is", epoch)
        print("Number of batches", n_batches)
        for index in range(n_batches):
            batch = slice(index * BATCH_SIZE, (index + 1) * BATCH_SIZE)
            input_batch = X_train[batch]
            image_batch = Y_train[batch]
            # Labels must be one-hot encoded for "categorical_crossentropy".
            label_batch = LABEL_train[batch]  # replace with your data here
            generated_images = generator.predict(input_batch)
            if index % 20 == 0:
                image = combine_images(generated_images)
                image = image * 127.5 + 127.5
                image = np.swapaxes(image, 0, 2)
                cv2.imwrite(str(epoch) + "_" + str(index) + ".png", image)

            # Training D: real pairs (input + target) followed by fake pairs
            # (input + generated), matching the order of d_targets.
            real_pairs = np.concatenate((input_batch, image_batch), axis=1)
            fake_pairs = np.concatenate((input_batch, generated_images), axis=1)
            X = np.concatenate((real_pairs, fake_pairs))
            d_loss = discriminator.train_on_batch(X, d_targets)
            print("batch %d d_loss : %f" % (index, d_loss))
            discriminator.trainable = False

            # Training C on real target images:
            c_loss = classifier.train_on_batch(image_batch, label_batch)
            print("batch %d c_loss : %f" % (index, c_loss))
            classifier.trainable = False

            # Train G: reproduce the target, fool D, and yield images the
            # classifier assigns to the right class.
            g_loss = discriminator_and_classifier_on_generator.train_on_batch(
                input_batch, [image_batch, g_d_targets, label_batch])
            discriminator.trainable = True
            classifier.trainable = True
            # g_loss[0] is the total loss; g_loss[1] is the generator-output (L1) loss.
            print("batch %d g_loss : %f" % (index, g_loss[1]))
            if index % 20 == 0:
                generator.save_weights('generator', True)
                discriminator.save_weights('discriminator', True)

Theoretical Details

I believe there are some misunderstandings regarding how conditional GANs work and what the discriminator's role is in such schemes.

Role of the Discriminator

In the min-max game which is GAN training [4], the discriminator D is playing against the generator G (the network you actually care about) so that under D's scrutiny, G becomes better at outputting realistic results.

For that, D is trained to tell real samples apart from samples generated by G, while G is trained to fool D by generating realistic results, i.e. results following the target distribution.

Note: in the case of conditional GANs, i.e. GANs mapping an input sample from one domain A (e.g. real picture) to another domain B (e.g. sketch), D is usually fed pairs of samples stacked together and has to discriminate between "real" pairs (input sample from A + corresponding target sample from B) and "fake" pairs (input sample from A + corresponding output from G) [1, 2].

Training a conditional generator against D (as opposed to simply training G alone with only an L1/L2 loss, e.g. a DAE) improves the sampling capability of G, forcing it to output crisp, realistic results instead of trying to average the distribution.

Even though discriminators can have multiple sub-networks to cover other tasks (see next paragraphs), D should keep at least one sub-network/output to cover its main task: telling real samples from generated ones apart. Asking D to regress further semantic information (e.g. classes) alongside may interfere with this main purpose.

Note: D output is often not a simple scalar / boolean. It is common to have a discriminator (e.g. PatchGAN [1, 2]) returning a matrix of probabilities, evaluating how realistic patches made from its input are.

Conditional GANs

Traditional GANs are trained in an unsupervised manner to generate realistic data (e.g. images) from a random noise vector as input. [4]

As previously mentioned, conditional GANs have additional input conditions. Alongside or instead of the noise vector, they take as input a sample from a domain A and return a corresponding sample from a domain B. A can be a completely different modality, e.g. B = sketch image while A = discrete label; B = volumetric data while A = RGB image, etc. [3]

Such GANs can also be conditioned on multiple inputs, e.g. A = real image + discrete label while B = sketch image. A famous work introducing such methods is InfoGAN [5]. It presents how to condition GANs on multiple continuous or discrete inputs (e.g. A = digit class + writing type, B = handwritten digit image), using a more advanced discriminator whose second task is to force G to maximize the mutual information between its conditioning inputs and its corresponding outputs.

Maximizing the Mutual Information for cGANs

InfoGAN discriminator has 2 heads/sub-networks to cover its 2 tasks [5]:

  • One head D1 does the traditional real/generated discrimination -- G has to minimize this result, i.e. it has to fool D1 so that it can't tell real from generated data apart;
  • Another head D2 (also named Q network) tries to regress the input A information -- G has to maximize this result, i.e. it has to output data which "show" the requested semantic information (c.f. mutual-information maximization between G conditional inputs and its outputs).

You can find Keras implementations of InfoGAN online, for instance (the link originally given here was not preserved).

Several works are using similar schemes to improve control over what a GAN is generating, by using provided labels and maximizing the mutual information between these inputs and G outputs [6, 7]. The basic idea is always the same though:

  • Train G to generate elements of domain B, given some inputs of domain(s) A;
  • Train D to discriminate "real"/"fake" results -- G has to minimize this;
  • Train Q (e.g. a classifier; it can share layers with D) to estimate the original A inputs from B samples -- G has to maximize this.

Wrapping Up

In your case, it seems you have the following training data:

  • real images Ia
  • corresponding sketch images Ib
  • corresponding class labels c

And you want to train a generator G so that given an image Ia and its class label c, it outputs a proper sketch image Ib'.

All in all, that's a lot of information you have, and you can supervise your training both on the conditioned images and the conditioned labels... Inspired by the aforementioned methods [1, 2, 5, 6, 7], here is a possible way of using all this information to train your conditional G:

Network G:
  • Inputs: Ia + c
  • Output: Ib'
  • Architecture: up-to-you (e.g. U-Net, ResNet, ...)
  • Losses: L1/L2 loss between Ib' & Ib, -D loss, Q loss
Network D:
  • Inputs: Ia + Ib (real pair), Ia + Ib' (fake pair)
  • Output: "fakeness" scalar/matrix
  • Architecture: up-to-you (e.g. PatchGAN)
  • Loss: cross-entropy on the "fakeness" estimation
Network Q:
  • Inputs: Ib (real sample, for training Q), Ib' (fake sample, when back-propagating through G)
  • Output: c' (estimated class)
  • Architecture: up-to-you (e.g. LeNet, ResNet, VGG, ...)
  • Loss: cross-entropy between c and c'
Training Phase:
  1. Train D on a batch of real pairs Ia + Ib then on a batch of fake pairs Ia + Ib';
  2. Train Q on a batch of real samples Ib;
  3. Fix D and Q weights;
  4. Train G, passing its generated outputs Ib' to D and Q to back-propagate through them.

Note: this is a really rough architecture description. I'd recommend going through the literature ([1, 5, 6, 7] as a good start) to get more details and maybe a more elaborate solution.


  1. Isola, Phillip, et al. "Image-to-image translation with conditional adversarial networks." arXiv preprint (2017).
  2. Zhu, Jun-Yan, et al. "Unpaired image-to-image translation using cycle-consistent adversarial networks." arXiv preprint arXiv:1703.10593 (2017).
  3. Mirza, Mehdi, and Simon Osindero. "Conditional generative adversarial nets." arXiv preprint arXiv:1411.1784 (2014).
  4. Goodfellow, Ian, et al. "Generative adversarial nets." Advances in neural information processing systems. 2014.
  5. Chen, Xi, et al. "Infogan: Interpretable representation learning by information maximizing generative adversarial nets." Advances in Neural Information Processing Systems. 2016.
  6. Lee, Minhyeok, and Junhee Seok. "Controllable Generative Adversarial Network." arXiv preprint arXiv:1708.00598 (2017).
  7. Odena, Augustus, Christopher Olah, and Jonathon Shlens. "Conditional image synthesis with auxiliary classifier gans." arXiv preprint arXiv:1610.09585 (2016).

You should modify your discriminator model, either to have two outputs, or to have a "n_classes + 1" output.

Warning: I don't see your discriminator definition outputting 'true/false'; it seems to output an image...

Somewhere it should contain a GlobalMaxPooling2D or a GlobalAveragePooling2D, followed at the end by one or more Dense layers for classification.

If telling true/false, the last Dense should have 1 unit.
Otherwise n_classes + 1 units.

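So, the ending of your discriminator should be something like a `GlobalAveragePooling2D()` layer followed by `Dense(n_classes + 1, activation='sigmoid')` (or `activation='softmax'` when using the "fake class" option below):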


The discriminator will now output the n_classes class scores plus either a "true/fake" flag (you will not be able to use "categorical_crossentropy" then) or an extra "fake class" (in which case you zero the other classes and can use "categorical_crossentropy").

Your generated sketches should be passed to the discriminator along with a target that is the concatenation of the fake flag/class with the other classes.

Option 1 - Using the "true/fake" sign. (Don't use "categorical_crossentropy")

# Option 1 targets: column 0 is the "is generated" flag (1 = fake), followed by
# the class columns. The flag arrays are (N, 1) columns: 1-D (N,) arrays cannot
# be concatenated with the 2-D class array along the last axis.
# NOTE(review): assumes originalClasses is one-hot with shape (N, n_classes) -- confirm.

# True sketches into the discriminator: flag 0 + their real class.
fakeClass = np.zeros((total_samples, 1))
sketchClass = originalClasses
targetClassTrue = np.concatenate([fakeClass, sketchClass], axis=-1)

# Fake sketches into the discriminator: flag 1 + the class they were generated
# for (requires total_fake_sketches == len(originalClasses)).
fakeClass = np.ones((total_fake_sketches, 1))
sketchClass = originalClasses
targetClassFake = np.concatenate([fakeClass, sketchClass], axis=-1)

Option 2 - Using the "fake class" (can use "categorical_crossentropy"):

# Option 2 targets: column 0 is a separate "fake" class, followed by the real
# classes, so each row is one-hot and "categorical_crossentropy" applies.
# The flag arrays are (N, 1) columns so they concatenate with the 2-D class arrays.
# NOTE(review): assumes originalClasses is one-hot with shape (N, n_classes) -- confirm.

# True sketches into the discriminator: "fake" class 0 + their real class.
fakeClass = np.zeros((total_samples, 1))
sketchClass = originalClasses
targetClassTrue = np.concatenate([fakeClass, sketchClass], axis=-1)

# Fake sketches into the discriminator: "fake" class 1, every real class 0.
fakeClass = np.ones((total_fake_sketches, 1))
sketchClass = np.zeros((total_fake_sketches, n_classes))
targetClassFake = np.concatenate([fakeClass, sketchClass], axis=-1)

Now concatenate everything into a single target array (in the same order as the input sketches).

Updated training method

For this training method, your loss function should be one of:

  • discriminator.compile(loss='binary_crossentropy', optimizer=....)
  • discriminator.compile(loss='categorical_crossentropy', optimizer=...)


# Target encoding used for generated sketches (see the two options above):
#   False -> Option 1: "is generated" flag + the real class      (use "binary_crossentropy")
#   True  -> Option 2: separate "fake" class, other classes zero  (use "categorical_crossentropy")
use_fake_class = True

n_batches = int(X_train.shape[0] / BATCH_SIZE)
for epoch in range(100):
    print("Epoch is", epoch)
    print("Number of batches", n_batches)
    for index in range(n_batches):
        # names:
        #   images      -> initial images, not changed
        #   sketches    -> generated + true sketches
        #   classes     -> your classification for the images
        #   isGenerated -> the discriminator output telling whether the passed sketches are fake
        batchSlice = slice(index * BATCH_SIZE, (index + 1) * BATCH_SIZE)
        trueImages = X_train[batchSlice]
        trueSketches = Y_train[batchSlice]
        trueClasses = originalClasses[batchSlice]
        # Generated = 1, true = 0. Shaped (n, 1) so it can be concatenated with
        # the (n, n_classes) class array along axis 1.
        trueIsGenerated = np.zeros((len(trueImages), 1))
        trueEndTargets = np.concatenate([trueIsGenerated, trueClasses], axis=1)

        fakeSketches = generator.predict(trueImages)
        if use_fake_class:
            fakeClasses = np.zeros((len(fakeSketches), n_classes))
        else:
            fakeClasses = originalClasses[batchSlice]
        fakeIsGenerated = np.ones((len(fakeSketches), 1))
        fakeEndTargets = np.concatenate([fakeIsGenerated, fakeClasses], axis=1)

        allSketches = np.concatenate([trueSketches, fakeSketches], axis=0)
        allEndTargets = np.concatenate([trueEndTargets, fakeEndTargets], axis=0)

        d_loss = discriminator.train_on_batch(allSketches, allEndTargets)
        print("batch %d d_loss : %f" % (index, d_loss))

        # WARNING: "trainable" only takes effect when a model is compiled. Prefer
        # setting it once when creating "discriminator" and
        # "discriminator_on_generator" and never changing it again; these
        # toggles are kept only for older Keras versions.
        discriminator.trainable = False
        # G is trained so that D sees its sketches as true and of the right class.
        g_loss = discriminator_on_generator.train_on_batch(trueImages, trueEndTargets)
        discriminator.trainable = True
        # train_on_batch returns a scalar for a single-output model without
        # metrics, and a list whose first entry is the total loss otherwise.
        if isinstance(g_loss, (list, tuple)):
            g_loss = g_loss[0]
        print("batch %d g_loss : %f" % (index, g_loss))
        if index % 20 == 0:
            generator.save_weights('generator', True)
            discriminator.save_weights('discriminator', True)

Compiling the models properly

When you create "discriminator" and "discriminator_on_generator":

# Standalone discriminator: every layer trainable when it is compiled.
discriminator.trainable = True
for l in discriminator.layers:
    l.trainable = True
# Use 'categorical_crossentropy' instead with Option 2 (the "fake class" encoding).
discriminator.compile(loss='binary_crossentropy', optimizer='rmsprop')

# Combined model: freeze the discriminator part (every layer from
# firstDiscriminatorLayer on) before compiling, so training it updates only G.
for l in discriminator_on_generator.layers[firstDiscriminatorLayer:]:
    l.trainable = False
discriminator_on_generator.compile(loss='binary_crossentropy', optimizer='rmsprop')