Add class information to keras network
Suggested Solution
Reusing the code from the repository you shared, here are some suggested modifications to train a classifier alongside your generator and discriminator (their architectures and other losses are left untouched):
    # Keras 1.x API, as in the original repository
    import numpy as np
    import cv2
    from keras import backend as K
    from keras.models import Sequential, Model
    from keras.layers import Input, merge
    from keras.layers.core import Dense, Dropout, Activation, Flatten
    from keras.layers.convolutional import Convolution2D, MaxPooling2D
    from keras.optimizers import Adagrad


    def lenet_classifier_model(nb_classes):
        # Snipped by Fabien Tanc - https://www.kaggle.com/ftence/keras-cnn-inspired-by-lenet-5
        # Replace with your favorite classifier...
        # Note: in_shape (the shape of one sketch) has to be defined for your data.
        model = Sequential()
        model.add(Convolution2D(12, 5, 5, activation='relu', input_shape=in_shape, init='he_normal'))
        model.add(MaxPooling2D(pool_size=(2, 2)))
        model.add(Convolution2D(25, 5, 5, activation='relu', init='he_normal'))
        model.add(MaxPooling2D(pool_size=(2, 2)))
        model.add(Flatten())
        model.add(Dense(180, activation='relu', init='he_normal'))
        model.add(Dropout(0.5))
        model.add(Dense(100, activation='relu', init='he_normal'))
        model.add(Dropout(0.5))
        model.add(Dense(nb_classes, activation='softmax', init='he_normal'))
        return model  # the original snippet forgot to return the model


    def generator_containing_discriminator_and_classifier(generator, discriminator, classifier):
        inputs = Input((IN_CH, img_cols, img_rows))
        x_generator = generator(inputs)

        # D judges the (input, generated output) pair, stacked along the channel axis:
        merged = merge([inputs, x_generator], mode='concat', concat_axis=1)
        discriminator.trainable = False
        x_discriminator = discriminator(merged)

        # The classifier only sees the generated output:
        classifier.trainable = False
        x_classifier = classifier(x_generator)

        model = Model(input=inputs, output=[x_generator, x_discriminator, x_classifier])
        return model


    def train(BATCH_SIZE):
        (X_train, Y_train, LABEL_train) = get_data('train')  # replace with your data here
        X_train = (X_train.astype(np.float32) - 127.5) / 127.5
        Y_train = (Y_train.astype(np.float32) - 127.5) / 127.5

        discriminator = discriminator_model()
        generator = generator_model()
        classifier = lenet_classifier_model(6)
        generator.summary()
        discriminator_and_classifier_on_generator = generator_containing_discriminator_and_classifier(
            generator, discriminator, classifier)

        d_optim = Adagrad(lr=0.005)
        g_optim = Adagrad(lr=0.005)
        generator.compile(loss='mse', optimizer="rmsprop")
        discriminator_and_classifier_on_generator.compile(
            loss=[generator_l1_loss, discriminator_on_generator_loss, "categorical_crossentropy"],
            optimizer="rmsprop")
        discriminator.trainable = True
        discriminator.compile(loss=discriminator_loss, optimizer="rmsprop")
        classifier.trainable = True
        classifier.compile(loss="categorical_crossentropy", optimizer="rmsprop")

        for epoch in range(100):
            print("Epoch is", epoch)
            print("Number of batches", int(X_train.shape[0] / BATCH_SIZE))
            for index in range(int(X_train.shape[0] / BATCH_SIZE)):
                image_batch = Y_train[index * BATCH_SIZE:(index + 1) * BATCH_SIZE]
                label_batch = LABEL_train[index * BATCH_SIZE:(index + 1) * BATCH_SIZE]  # replace with your data here

                generated_images = generator.predict(X_train[index * BATCH_SIZE:(index + 1) * BATCH_SIZE])
                if index % 20 == 0:
                    image = combine_images(generated_images)
                    image = image * 127.5 + 127.5
                    image = np.swapaxes(image, 0, 2)
                    cv2.imwrite(str(epoch) + "_" + str(index) + ".png", image)
                    # Image.fromarray(image.astype(np.uint8)).save(str(epoch)+"_"+str(index)+".png")

                # Training D:
                real_pairs = np.concatenate(
                    (X_train[index * BATCH_SIZE:(index + 1) * BATCH_SIZE, :, :, :], image_batch), axis=1)
                fake_pairs = np.concatenate(
                    (X_train[index * BATCH_SIZE:(index + 1) * BATCH_SIZE, :, :, :], generated_images), axis=1)
                X = np.concatenate((real_pairs, fake_pairs))
                y = np.zeros((20, 1, 64, 64))  # [1] * BATCH_SIZE + [0] * BATCH_SIZE
                d_loss = discriminator.train_on_batch(X, y)
                print("batch %d d_loss : %f" % (index, d_loss))
                discriminator.trainable = False

                # Training C:
                c_loss = classifier.train_on_batch(image_batch, label_batch)
                print("batch %d c_loss : %f" % (index, c_loss))
                classifier.trainable = False

                # Train G:
                g_loss = discriminator_and_classifier_on_generator.train_on_batch(
                    X_train[index * BATCH_SIZE:(index + 1) * BATCH_SIZE, :, :, :],
                    [image_batch, np.ones((10, 1, 64, 64)), label_batch])
                discriminator.trainable = True
                classifier.trainable = True
                print("batch %d g_loss : %f" % (index, g_loss[1]))

                if index % 20 == 0:
                    generator.save_weights('generator', True)
                    discriminator.save_weights('discriminator', True)
Theoretical Details
I believe there are some misunderstandings regarding how conditional GANs work and what the discriminator's role is in such schemes.
Role of the Discriminator
In the min-max game which is GAN training [4], the discriminator D is playing against the generator G (the network you actually care about) so that under D's scrutiny, G becomes better at outputting realistic results.
For that, D is trained to tell apart real samples from samples from G; while G is trained to fool D by generating realistic results / results following the target distribution.
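For reference, this min-max game is usually written as the following value function, using the notation of the original GAN paper [4] (x is a real sample, z the noise input; neither symbol appears in the code above):

```latex
\min_G \max_D V(D, G) =
  \mathbb{E}_{x \sim p_{\text{data}}(x)}\big[\log D(x)\big]
  + \mathbb{E}_{z \sim p_z(z)}\big[\log\big(1 - D(G(z))\big)\big]
```

D pushes the value up by scoring real samples near 1 and generated ones near 0; G pushes it down by making D(G(z)) approach 1.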
Note: in the case of conditional GANs, i.e. GANs mapping an input sample from one domain A (e.g. real picture) to another domain B (e.g. sketch), D is usually fed with the pairs of samples stacked together and has to discriminate "real" pairs (input sample from A + corresponding target sample from B) and "fake" pairs (input sample from A + corresponding output from G) [1, 2].
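To make the pairing concrete, here is a minimal NumPy sketch of how such "real" and "fake" pairs could be assembled. The shapes, the channels-first layout and the random arrays (standing in for actual images and for the output of G) are assumptions made for illustration only.

```python
import numpy as np

# Hypothetical shapes: 4 samples, 1 channel, 64x64 pixels, channels-first.
batch, channels, height, width = 4, 1, 64, 64
rng = np.random.default_rng(0)

images_a = rng.uniform(-1, 1, (batch, channels, height, width))     # inputs from A
targets_b = rng.uniform(-1, 1, (batch, channels, height, width))    # ground truth from B
generated_b = rng.uniform(-1, 1, (batch, channels, height, width))  # stand-in for G(images_a)

# Stack each input with its target / generated output along the channel axis.
real_pairs = np.concatenate([images_a, targets_b], axis=1)
fake_pairs = np.concatenate([images_a, generated_b], axis=1)

# D sees both kinds of pairs in one batch, here with 1 = real and 0 = fake.
d_inputs = np.concatenate([real_pairs, fake_pairs], axis=0)
d_labels = np.concatenate([np.ones(batch), np.zeros(batch)])

print(d_inputs.shape, d_labels.shape)
```

Because the input from A is part of every pair, D can penalize outputs that look realistic but do not match the input they were generated from.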
Training a conditional generator against D (as opposed to simply training G alone, with a L1/L2 loss only, e.g. DAE) improves the sampling capability of G, forcing it to output crisp, realistic results instead of trying to average the distribution.
Even though discriminators can have multiple sub-networks to cover other tasks (see next paragraphs), D should keep at least one sub-network/output to cover its main task: telling real samples from generated ones apart. Asking D to regress further semantic information (e.g. classes) alongside may interfere with this main purpose.
Note: D's output is often not a simple scalar / boolean. It is common to have a discriminator (e.g. PatchGAN [1, 2]) returning a matrix of probabilities, evaluating how realistic patches made from its input are.
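As an illustration of such a matrix output, the sketch below computes a binary cross-entropy averaged over all patches. The 8x8 patch grid, the `patch_bce` helper and the constant predictions are hypothetical; a real PatchGAN would produce the probabilities with convolutional layers.

```python
import numpy as np

def patch_bce(pred, target, eps=1e-7):
    """Binary cross-entropy averaged over every patch of every sample."""
    pred = np.clip(pred, eps, 1.0 - eps)
    return float(np.mean(-(target * np.log(pred) + (1 - target) * np.log(1 - pred))))

# Hypothetical discriminator output: one probability per patch on an 8x8 grid.
batch, grid = 2, 8
real_target = np.ones((batch, 1, grid, grid))   # every patch of a real pair is "real"
fake_target = np.zeros((batch, 1, grid, grid))  # every patch of a fake pair is "fake"

undecided = np.full((batch, 1, grid, grid), 0.5)   # D answers 0.5 for every patch
confident = np.full((batch, 1, grid, grid), 0.99)  # D is nearly sure every patch is real

loss_undecided = patch_bce(undecided, real_target)       # equals ln(2)
loss_confident = patch_bce(confident, real_target)       # close to 0
loss_fooled = patch_bce(confident, fake_target)          # large: D was fooled on fakes
```

The target simply has the same shape as the output, which is why the training code above builds arrays such as `np.ones((10, 1, 64, 64))` instead of a vector of labels.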
Conditional GANs
Traditional GANs are trained in an unsupervised manner to generate realistic data (e.g. images) from a random noise vector as input. [4]
As previously mentioned, conditional GANs have further input conditions. Along with, or instead of, the noise vector, they take as input a sample from a domain A and return a corresponding sample from a domain B. A can be a completely different modality, e.g. B = sketch image while A = discrete label; B = volumetric data while A = RGB image, etc. [3]
Such GANs can also be conditioned on multiple inputs, e.g. A = real image + discrete label while B = sketch image. A famous work introducing such methods is InfoGAN [5]. It presents how to condition GANs on multiple continuous or discrete inputs (e.g. A = digit class + writing type, B = handwritten digit image), using a more advanced discriminator whose second task is to force G to maximize the mutual information between its conditioning inputs and its corresponding outputs.
Maximizing the Mutual Information for cGANs
InfoGAN discriminator has 2 heads/sub-networks to cover its 2 tasks [5]:
- One head D1 does the traditional real/generated discrimination -- G has to minimize this result, i.e. it has to fool D1 so that it can't tell apart real from generated data;
- Another head D2 (also named Q network) tries to regress the input A information -- G has to maximize this result, i.e. it has to output data which "show" the requested semantic information (c.f. mutual-information maximization between G's conditional inputs and its outputs).
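In the notation of the InfoGAN paper [5] (c is the conditioning code, z the noise, and λ a weighting hyper-parameter), the two heads correspond to the two terms of the objective below, where V(D, G) is the standard GAN value function and L_I is a lower bound on the mutual information estimated through Q:

```latex
\min_{G, Q} \max_D \; V_{\text{InfoGAN}}(D, G, Q) = V(D, G) - \lambda \, L_I(G, Q)

L_I(G, Q) = \mathbb{E}_{c \sim P(c),\, x \sim G(z, c)}\big[\log Q(c \mid x)\big] + H(c)
\;\le\; I\big(c;\, G(z, c)\big)
```

For a discrete class label, maximizing log Q(c | x) is the same as minimizing the categorical cross-entropy of a classifier applied to the generated sample.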
You can find a Keras implementation here for instance: https://github.com/eriklindernoren/Keras-GAN/tree/master/infogan.
Several works are using similar schemes to improve control over what a GAN is generating, by using provided labels and maximizing the mutual information between these inputs and G's outputs [6, 7]. The basic idea is always the same though:
- Train G to generate elements of domain B, given some inputs of domain(s) A;
- Train D to discriminate "real"/"fake" results -- G has to minimize this;
- Train Q (e.g. a classifier; can share layers with D) to estimate the original A inputs from B samples -- G has to maximize this.
Wrapping Up
In your case, it seems you have the following training data:
- real images Ia
- corresponding sketch images Ib
- corresponding class labels c
And you want to train a generator G so that given an image Ia and its class label c, it outputs a proper sketch image Ib'.
All in all, that's a lot of information you have, and you can supervise your training both on the conditioned images and the conditioned labels... Inspired by the aforementioned methods [1, 2, 5, 6, 7], here is a possible way of using all this information to train your conditional G:
G:
- Inputs: Ia + c
- Output: Ib'
- Architecture: up-to-you (e.g. U-Net, ResNet, ...)
- Losses: L1/L2 loss between Ib' & Ib, -D loss, Q loss

D:
- Inputs: Ia + Ib (real pair), Ia + Ib' (fake pair)
- Output: "fakeness" scalar/matrix
- Architecture: up-to-you (e.g. PatchGAN)
- Loss: cross-entropy on the "fakeness" estimation

Q:
- Inputs: Ib (real sample, for training Q), Ib' (fake sample, when back-propagating through G)
- Output: c' (estimated class)
- Architecture: up-to-you (e.g. LeNet, ResNet, VGG, ...)
- Loss: cross-entropy between c and c'

Training procedure:
- Train D on a batch of real pairs Ia + Ib, then on a batch of fake pairs Ia + Ib';
- Train Q on a batch of real samples Ib;
- Fix D and Q weights;
- Train G, passing its generated outputs Ib' to D and Q to back-propagate through them.
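The order of these steps can be sketched as a small framework-agnostic function. Everything here is a placeholder: `generate`, `train_d`, `train_q` and `train_g` are hypothetical callables you would implement with your own models (e.g. wrapping `train_on_batch`), and the stubs at the bottom only record the call order.

```python
def train_step(generate, train_d, train_q, train_g, images_a, sketches_b, labels_c):
    """One iteration of the schedule; every callable is a hypothetical placeholder."""
    fake_b = generate(images_a)
    losses = {
        "d_real": train_d(images_a, sketches_b, 1.0),  # real pairs Ia + Ib
        "d_fake": train_d(images_a, fake_b, 0.0),      # fake pairs Ia + Ib'
    }
    losses["q"] = train_q(sketches_b, labels_c)        # Q only sees real sketches
    # train_g is expected to keep D and Q frozen while back-propagating through them.
    losses["g"] = train_g(images_a, sketches_b, labels_c)
    return losses


# Dummy stand-ins that only record the call order, to show the schedule runs.
calls = []

def _stub(name, value):
    def fn(*args):
        calls.append(name)
        return value
    return fn

losses = train_step(
    generate=lambda batch: [0.0 for _ in batch],
    train_d=_stub("D", 0.7),
    train_q=_stub("Q", 1.1),
    train_g=_stub("G", 2.0),
    images_a=[1.0, 2.0],
    sketches_b=[0.5, 0.5],
    labels_c=[0, 1],
)
print(calls, losses)
```

Keeping the schedule in one place makes it easier to experiment, for instance by updating D several times per G update.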
Note: this is a really rough architecture description. I'd recommend going through the literature ([1, 5, 6, 7] as a good start) to get more details and maybe a more elaborate solution.
References
1. Isola, Phillip, et al. "Image-to-image translation with conditional adversarial networks." arXiv preprint (2017). http://openaccess.thecvf.com/content_cvpr_2017/papers/Isola_Image-To-Image_Translation_With_CVPR_2017_paper.pdf
2. Zhu, Jun-Yan, et al. "Unpaired image-to-image translation using cycle-consistent adversarial networks." arXiv preprint arXiv:1703.10593 (2017). http://openaccess.thecvf.com/content_ICCV_2017/papers/Zhu_Unpaired_Image-To-Image_Translation_ICCV_2017_paper.pdf
3. Mirza, Mehdi, and Simon Osindero. "Conditional generative adversarial nets." arXiv preprint arXiv:1411.1784 (2014). https://arxiv.org/pdf/1411.1784
4. Goodfellow, Ian, et al. "Generative adversarial nets." Advances in neural information processing systems. 2014. http://papers.nips.cc/paper/5423-generative-adversarial-nets.pdf
5. Chen, Xi, et al. "Infogan: Interpretable representation learning by information maximizing generative adversarial nets." Advances in Neural Information Processing Systems. 2016. http://papers.nips.cc/paper/6399-infogan-interpretable-representation-learning-by-information-maximizing-generative-adversarial-nets.pdf
6. Lee, Minhyeok, and Junhee Seok. "Controllable Generative Adversarial Network." arXiv preprint arXiv:1708.00598 (2017). https://arxiv.org/pdf/1708.00598.pdf
7. Odena, Augustus, Christopher Olah, and Jonathon Shlens. "Conditional image synthesis with auxiliary classifier gans." arXiv preprint arXiv:1610.09585 (2016). http://proceedings.mlr.press/v70/odena17a/odena17a.pdf
You should modify your discriminator model, either to have two outputs, or to have a "n_classes + 1" output.
Warning: I don't see your discriminator outputting 'true/false' anywhere in its definition; I see it outputting an image...
Somewhere it should contain a GlobalMaxPooling2D or a GlobalAveragePooling2D layer, followed at the end by one or more Dense layers for classification.
If it only tells true/false, the last Dense should have 1 unit. Otherwise, it should have n_classes + 1 units.
So, the ending of your discriminator should be something like
    ...
    GlobalMaxPooling2D()
    ...
    Dense(someHidden, ...)
    ...
    Dense(n_classes+1, ...)
    ...
The discriminator will now output n_classes values plus either a "true/fake" sign (you will not be able to use "categorical" there) or even a "fake class" (then you zero the other classes and use categorical).
Your generated sketches should be passed to the discriminator along with a target that will be the concatenation of the fake class with the other classes.
Option 1 - Using the "true/fake" sign. (Don't use "categorical_crossentropy")
    # true sketches into discriminator:
    fakeClass = np.zeros((total_samples, 1))  # 2D column, so it can be concatenated with the one-hot classes
    sketchClass = originalClasses
    targetClassTrue = np.concatenate([fakeClass, sketchClass], axis=-1)

    # fake sketches into discriminator:
    fakeClass = np.ones((total_fake_sketches, 1))
    sketchClass = originalClasses
    targetClassFake = np.concatenate([fakeClass, sketchClass], axis=-1)
Option 2 - Using the "fake class" (can use "categorical_crossentropy"):
    # true sketches into discriminator:
    fakeClass = np.zeros((total_samples, 1))  # 2D column, so it can be concatenated with the one-hot classes
    sketchClass = originalClasses
    targetClassTrue = np.concatenate([fakeClass, sketchClass], axis=-1)

    # fake sketches into discriminator:
    fakeClass = np.ones((total_fake_sketches, 1))
    sketchClass = np.zeros((total_fake_sketches, n_classes))  # all real classes zeroed
    targetClassFake = np.concatenate([fakeClass, sketchClass], axis=-1)
Now concatenate everything into a single target array (in the same order as the input sketches).
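The two options can be compared side by side with a small self-contained NumPy sketch. The `build_targets` helper, the 3 classes and the toy labels are made up for illustration, and the classes are assumed to be one-hot encoded.

```python
import numpy as np

def build_targets(classes_onehot, generated, fake_is_a_class):
    """Prepend the "is generated" column to the one-hot classes.

    fake_is_a_class=False -> Option 1: generated sketches keep their class.
    fake_is_a_class=True  -> Option 2: generated sketches get all classes zeroed.
    """
    n_samples = len(classes_onehot)
    flag = np.full((n_samples, 1), 1.0 if generated else 0.0)
    if generated and fake_is_a_class:
        classes = np.zeros_like(classes_onehot)
    else:
        classes = classes_onehot
    return np.concatenate([flag, classes], axis=-1)

# Toy data: 3 sketches with 3 possible classes (hypothetical labels).
original_classes = np.eye(3)[[0, 2, 1]]

# Option 1: targets for true sketches followed by targets for generated sketches.
option1 = np.concatenate([
    build_targets(original_classes, generated=False, fake_is_a_class=False),
    build_targets(original_classes, generated=True, fake_is_a_class=False),
], axis=0)

# Option 2: "fake" behaves like one more class, so every row stays one-hot.
option2 = np.concatenate([
    build_targets(original_classes, generated=False, fake_is_a_class=True),
    build_targets(original_classes, generated=True, fake_is_a_class=True),
], axis=0)

print(option1)
print(option2)
```

With Option 2 every row sums to 1, which is what "categorical_crossentropy" expects; with Option 1 the rows of generated sketches contain two ones, hence the advice to use "binary_crossentropy" there.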
Updated training method
For this training method, your loss function should be one of:
    discriminator.compile(loss='binary_crossentropy', optimizer=....)       # Option 1
    discriminator.compile(loss='categorical_crossentropy', optimizer=...)   # Option 2
Code:
    for epoch in range(100):
        print("Epoch is", epoch)
        print("Number of batches", int(X_train.shape[0]/BATCH_SIZE))
        for index in range(int(X_train.shape[0]/BATCH_SIZE)):
            # names:
            #   images      -> initial images, not changed
            #   sketches    -> generated + true sketches
            #   classes     -> your classification for the images
            #   isGenerated -> the output of your discriminator telling whether the passed sketches are fake

            batchSlice = slice(index*BATCH_SIZE, (index+1)*BATCH_SIZE)
            trueImages = X_train[batchSlice]

            trueSketches = Y_train[batchSlice]
            trueClasses = originalClasses[batchSlice]
            # discriminator telling whether the sketch is fake or true (generated images = 1);
            # kept as a 2D column so it can be concatenated with the classes
            trueIsGenerated = np.zeros((len(trueImages), 1))
            trueEndTargets = np.concatenate([trueIsGenerated, trueClasses], axis=1)

            fakeSketches = generator.predict(trueImages)
            # keep only ONE of the two following lines:
            fakeClasses = originalClasses[batchSlice]  # option 1 -> telling class + isGenerated - use "binary_crossentropy"
            fakeClasses = np.zeros((len(fakeSketches), n_classes))  # option 2 -> "generated" is an individual class - use "categorical_crossentropy"
            fakeIsGenerated = np.ones((len(fakeSketches), 1))
            fakeEndTargets = np.concatenate([fakeIsGenerated, fakeClasses], axis=1)

            allSketches = np.concatenate([trueSketches, fakeSketches], axis=0)
            allEndTargets = np.concatenate([trueEndTargets, fakeEndTargets], axis=0)

            d_loss = discriminator.train_on_batch(allSketches, allEndTargets)
            pred_temp = discriminator.predict(allSketches)
            # print(np.shape(pred_temp))
            print("batch %d d_loss : %f" % (index, d_loss))

            ## WARNING ## In previous keras versions, "trainable" only takes effect if you compile the models.
            # You should have the "discriminator" and the "discriminator_on_generator" with these
            # set at the creation of the models and never change it again.
            discriminator.trainable = False
            g_loss = discriminator_on_generator.train_on_batch(trueImages, trueEndTargets)
            discriminator.trainable = True

            print("batch %d g_loss : %f" % (index, g_loss[1]))
            if index % 20 == 0:
                generator.save_weights('generator', True)
                discriminator.save_weights('discriminator', True)
Compiling the models properly
When you create "discriminator" and "discriminator_on_generator":
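    # standalone discriminator: everything trainable
    discriminator.trainable = True
    for l in discriminator.layers:
        l.trainable = True
    discriminator.compile(.....)

    # combined model: freeze the discriminator part before compiling
    for l in discriminator_on_generator.layers[firstDiscriminatorLayer:]:
        l.trainable = False
    discriminator_on_generator.compile(....)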