55 lines
2.5 KiB
Python
55 lines
2.5 KiB
Python
from keras.engine.topology import Layer, InputSpec
|
|
import keras.utils.conv_utils as conv_utils
|
|
import tensorflow as tf
|
|
import keras.backend as K
|
|
|
|
def normalize_data_format(value):
    """Validate and canonicalize an image data-format name.

    Args:
        value: ``None`` or a string naming the data format. When ``None``,
            the result of ``K.image_data_format()`` is used instead.
            Matching is case-insensitive.

    Returns:
        The lower-cased data format: ``'channels_first'`` or
        ``'channels_last'``.

    Raises:
        ValueError: If ``value`` is not a string, or does not name one of
            the two supported formats.
    """
    if value is None:
        value = K.image_data_format()
    # A non-string has no usable ``lower()``; map it to an invalid marker so
    # the caller gets the ValueError below instead of an AttributeError.
    data_format = value.lower() if isinstance(value, str) else None
    if data_format not in {'channels_first', 'channels_last'}:
        raise ValueError('The `data_format` argument must be one of '
                         '"channels_first", "channels_last". Received: ' +
                         str(value))
    return data_format
|
|
|
|
|
|
class BilinearUpSampling2D(Layer):
    """Layer that enlarges the spatial axes of a 4-D tensor bilinearly.

    The height and width of the input are multiplied by the two factors in
    ``size`` and the values are interpolated with
    ``tf.image.resize(..., method=BILINEAR)``.

    Args:
        size: Upsampling factors for (height, width); normalized to a
            2-tuple by ``conv_utils.normalize_tuple``.
        data_format: ``'channels_first'``, ``'channels_last'`` or ``None``
            (``None`` falls back to ``K.image_data_format()``).
        **kwargs: Forwarded unchanged to the base ``Layer`` constructor.
    """

    def __init__(self, size=(2, 2), data_format=None, **kwargs):
        super(BilinearUpSampling2D, self).__init__(**kwargs)
        self.data_format = normalize_data_format(data_format)
        self.size = conv_utils.normalize_tuple(size, 2, 'size')
        self.input_spec = InputSpec(ndim=4)

    def compute_output_shape(self, input_shape):
        """Return the static output shape for a 4-D ``input_shape``.

        Spatial entries that are ``None`` (unknown) stay ``None``; known
        ones are multiplied by the matching factor in ``self.size``.
        """
        if self.data_format == 'channels_first':
            height = self.size[0] * input_shape[2] if input_shape[2] is not None else None
            width = self.size[1] * input_shape[3] if input_shape[3] is not None else None
            return (input_shape[0],
                    input_shape[1],
                    height,
                    width)
        elif self.data_format == 'channels_last':
            height = self.size[0] * input_shape[1] if input_shape[1] is not None else None
            width = self.size[1] * input_shape[2] if input_shape[2] is not None else None
            return (input_shape[0],
                    height,
                    width,
                    input_shape[3])

    def call(self, inputs):
        """Bilinearly resize ``inputs`` by ``self.size`` along height/width.

        Args:
            inputs: 4-D tensor laid out according to ``self.data_format``.

        Returns:
            The resized tensor, in the same layout as ``inputs``.
        """
        channels_first = self.data_format == 'channels_first'
        if channels_first:
            # tf.image.resize treats 4-D input as
            # [batch, height, width, channels], so NCHW data has to be moved
            # to NHWC first; otherwise the channel axis would be resized.
            inputs = tf.transpose(inputs, [0, 2, 3, 1])

        # Use the runtime shape so that unknown static dimensions still work.
        # Slices of this shape tensor are tensors, never None, so no None
        # guard is needed here (unlike in compute_output_shape).
        runtime_shape = K.shape(inputs)
        height = self.size[0] * runtime_shape[1]
        width = self.size[1] * runtime_shape[2]

        outputs = tf.image.resize(inputs, [height, width],
                                  method=tf.image.ResizeMethod.BILINEAR)
        if channels_first:
            # Restore the caller's NCHW layout.
            outputs = tf.transpose(outputs, [0, 3, 1, 2])
        return outputs

    def get_config(self):
        """Return the layer config: the base config plus size/data_format."""
        config = {'size': self.size, 'data_format': self.data_format}
        base_config = super(BilinearUpSampling2D, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))
|