16 августа 2022
Вот структуры:
где channels == 3
def build_generator(self):
# функция для создания блока CNN, увеличивающего размеры изображения
def add_generator_block(x_block, filters):
block = Conv2DTranspose(filters, filter_size, strides=2, padding='same')(x_block)
block = BatchNormalization()(block)
block = LeakyReLU(0.3)(block)
return block
start_filters = 16
filter_size = [5, 5]
latent_dim = 100
# вход - это вектор шума
inp = Input(shape=(latent_dim,))
# проекция вектора шума в тензор с такой же размерностью,
# как последний сверточный слой дискриминатора
x = Dense(4 * 4 * (start_filters * 8), input_dim=latent_dim)(inp)
x = BatchNormalization()(x)
x = Reshape(target_shape=(4, 4, start_filters * 8))(x)
# строим генератор для увеличения изображения в 4 раза
x = add_generator_block(x, start_filters * 4)
x = add_generator_block(x, start_filters * 2)
x = add_generator_block(x, start_filters)
x = add_generator_block(x, start_filters)
# превращаем вывод в трехмерный тензор, изображение с 3 каналами
x = Conv2D(self.channels, kernel_size=5, padding='same', activation='tanh')(x)
model = Model(inputs=inp, outputs=x)
print('generator')
model.summary()
return model
def build_discriminator(self):
# функция создания блока CNN block для уменьшения размера изображения
def add_discriminator_block(x_block, filters):
block = Conv2D(filters, filter_size, padding='same')(x_block)
block = BatchNormalization()(block)
block = Conv2D(filters, filter_size, padding='same', strides=2)(block)
block = BatchNormalization()(block)
block = LeakyReLU(0.3)(block)
return block
start_filters = 16
filter_size = [5, 5]
#inp = Input(shape=(self.img_rows, self.img_cols, self.channels))
inp = Input(shape=(64, 64, self.channels))
# строим дискриминатор для уменьшения изображения
x = add_discriminator_block(inp, start_filters)
x = add_discriminator_block(x, start_filters * 2)
x = add_discriminator_block(x, start_filters * 4)
x = add_discriminator_block(x, start_filters * 8)
# усреднение и возврат бинарного вывода
x = GlobalAveragePooling2D()(x)
x = Dense(1, activation='sigmoid')(x)
model = Model(inputs=inp, outputs=x)
print('discriminator')
model.summary()
return model


Ответить
Пожаловаться