From 929dae88414026ac9dc760cb01d23e7d99a4a7ce Mon Sep 17 00:00:00 2001 From: Narcon Nicolas <nicolas.narcon@inrae.fr> Date: Thu, 21 Apr 2022 16:55:58 +0200 Subject: [PATCH 1/4] ENH: generate samples of same type as initial raster --- python/otbtf.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/python/otbtf.py b/python/otbtf.py index b28a1cc4..a1cf9bd4 100644 --- a/python/otbtf.py +++ b/python/otbtf.py @@ -58,8 +58,11 @@ def read_as_np_arr(gdal_ds, as_patches=True): False, the shape is (1, psz_y, psz_x, nb_channels) :return: Numpy array of dim 4 """ - buffer = gdal_ds.ReadAsArray() + gdal_to_np_types = {1: 'uint8', 2: 'uint16', 3: 'int16', 4: 'uint32', 5: 'int32', 6: 'float32', 7: 'float64', + 10: 'complex64', 11: 'complex128'} + gdal_type = gdal_ds.GetRasterBand(1).DataType size_x = gdal_ds.RasterXSize + buffer = gdal_ds.ReadAsArray().astype(gdal_to_np_types[gdal_type]) if len(buffer.shape) == 3: buffer = np.transpose(buffer, axes=(1, 2, 0)) if not as_patches: @@ -68,7 +71,7 @@ def read_as_np_arr(gdal_ds, as_patches=True): else: n_elems = int(gdal_ds.RasterYSize / size_x) size_y = size_x - return np.float32(buffer.reshape((n_elems, size_y, size_x, gdal_ds.RasterCount))) + return buffer.reshape((n_elems, size_y, size_x, gdal_ds.RasterCount)) # -------------------------------------------------- Buffer class ------------------------------------------------------ @@ -244,8 +247,11 @@ class PatchesImagesReader(PatchesReaderBase): @staticmethod def _read_extract_as_np_arr(gdal_ds, offset): + gdal_to_np_types = {1: 'uint8', 2: 'uint16', 3: 'int16', 4: 'uint32', 5: 'int32', 6: 'float32', 7: 'float64', + 10: 'complex64', 11: 'complex128'} assert gdal_ds is not None psz = gdal_ds.RasterXSize + gdal_type = gdal_ds.GetRasterBand(1).DataType yoff = int(offset * psz) assert yoff + psz <= gdal_ds.RasterYSize buffer = gdal_ds.ReadAsArray(0, yoff, psz, psz) @@ -254,7 +260,7 @@ class PatchesImagesReader(PatchesReaderBase): else: # single-band raster 
buffer = np.expand_dims(buffer, axis=2) - return np.float32(buffer) + return buffer.astype(gdal_to_np_types[gdal_type]) def get_sample(self, index): """ -- GitLab From f1bda5b1336a6a8a1d63204175e11a885b635665 Mon Sep 17 00:00:00 2001 From: Remi Cresson <remi.cresson@irstea.fr> Date: Fri, 22 Apr 2022 16:29:51 +0200 Subject: [PATCH 2/4] REFAC: gdal_to_np_type is global constant --- python/otbtf.py | 17 +++++++++++++---- 1 file changed, 13 insertions(+), 4 deletions(-) diff --git a/python/otbtf.py b/python/otbtf.py index a1cf9bd4..a58d10d2 100644 --- a/python/otbtf.py +++ b/python/otbtf.py @@ -34,6 +34,19 @@ import tensorflow as tf from osgeo import gdal from tqdm import tqdm +# --------------------------------------------- GDAL to numpy types ---------------------------------------------------- + + +gdal_to_np_types = {1: 'uint8', + 2: 'uint16', + 3: 'int16', + 4: 'uint32', + 5: 'int32', + 6: 'float32', + 7: 'float64', + 10: 'complex64', + 11: 'complex128'} + # ----------------------------------------------------- Helpers -------------------------------------------------------- @@ -58,8 +71,6 @@ def read_as_np_arr(gdal_ds, as_patches=True): False, the shape is (1, psz_y, psz_x, nb_channels) :return: Numpy array of dim 4 """ - gdal_to_np_types = {1: 'uint8', 2: 'uint16', 3: 'int16', 4: 'uint32', 5: 'int32', 6: 'float32', 7: 'float64', - 10: 'complex64', 11: 'complex128'} gdal_type = gdal_ds.GetRasterBand(1).DataType size_x = gdal_ds.RasterXSize buffer = gdal_ds.ReadAsArray().astype(gdal_to_np_types[gdal_type]) @@ -247,8 +258,6 @@ class PatchesImagesReader(PatchesReaderBase): @staticmethod def _read_extract_as_np_arr(gdal_ds, offset): - gdal_to_np_types = {1: 'uint8', 2: 'uint16', 3: 'int16', 4: 'uint32', 5: 'int32', 6: 'float32', 7: 'float64', - 10: 'complex64', 11: 'complex128'} assert gdal_ds is not None psz = gdal_ds.RasterXSize gdal_type = gdal_ds.GetRasterBand(1).DataType -- GitLab From ae46d5e19cc49714b941bd749f3d876a06ee0a0f Mon Sep 17 00:00:00 2001 From: Remi 
Cresson <remi.cresson@irstea.fr> Date: Fri, 22 Apr 2022 16:35:31 +0200 Subject: [PATCH 3/4] REFAC: pylint --- python/otbtf.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/python/otbtf.py b/python/otbtf.py index a58d10d2..2083250a 100644 --- a/python/otbtf.py +++ b/python/otbtf.py @@ -37,7 +37,7 @@ from tqdm import tqdm # --------------------------------------------- GDAL to numpy types ---------------------------------------------------- -gdal_to_np_types = {1: 'uint8', +GDAL_TO_NP_TYPES = {1: 'uint8', 2: 'uint16', 3: 'int16', 4: 'uint32', @@ -73,7 +73,7 @@ def read_as_np_arr(gdal_ds, as_patches=True): """ gdal_type = gdal_ds.GetRasterBand(1).DataType size_x = gdal_ds.RasterXSize - buffer = gdal_ds.ReadAsArray().astype(gdal_to_np_types[gdal_type]) + buffer = gdal_ds.ReadAsArray().astype(GDAL_TO_NP_TYPES[gdal_type]) if len(buffer.shape) == 3: buffer = np.transpose(buffer, axes=(1, 2, 0)) if not as_patches: @@ -269,7 +269,7 @@ class PatchesImagesReader(PatchesReaderBase): else: # single-band raster buffer = np.expand_dims(buffer, axis=2) - return buffer.astype(gdal_to_np_types[gdal_type]) + return buffer.astype(GDAL_TO_NP_TYPES[gdal_type]) def get_sample(self, index): """ @@ -628,8 +628,8 @@ class TFRecords: """ data_converted = {} - for k, d in data.items(): - data_converted[k] = d.name + for key, value in data.items(): + data_converted[key] = value.name return data_converted @@ -644,7 +644,7 @@ class TFRecords: filepath = os.path.join(self.dirpath, f"{i}.records") with tf.io.TFRecordWriter(filepath) as writer: - for s in range(nb_sample): + for _ in range(nb_sample): sample = dataset.read_one_sample() serialized_sample = {name: tf.io.serialize_tensor(fea) for name, fea in sample.items()} features = {name: self._bytes_feature(serialized_tensor) for name, serialized_tensor in @@ -661,8 +661,8 @@ class TFRecords: :param filepath: Output file name """ - with open(filepath, 'w') as f: - json.dump(data, f, indent=4) + with 
open(filepath, 'w') as file: + json.dump(data, file, indent=4) @staticmethod def load(filepath): -- GitLab From b1f064f0f11614ef8e40e9f168b50080d1ad2e14 Mon Sep 17 00:00:00 2001 From: Remi Cresson <remi.cresson@irstea.fr> Date: Fri, 22 Apr 2022 16:39:27 +0200 Subject: [PATCH 4/4] REFAC: pylint --- python/otbtf.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/python/otbtf.py b/python/otbtf.py index 2083250a..d7e1a0b0 100644 --- a/python/otbtf.py +++ b/python/otbtf.py @@ -670,8 +670,8 @@ class TFRecords: Return data from pickle format. :param filepath: Input file name """ - with open(filepath, 'r') as f: - return json.load(f) + with open(filepath, 'r') as file: + return json.load(file) def convert_dataset_output_shapes(self, dataset): """ @@ -680,8 +680,8 @@ class TFRecords: """ output_shapes = {} - for key in dataset.output_shapes.keys(): - output_shapes[key] = (None,) + dataset.output_shapes[key] + for key, value in dataset.output_shapes.items(): + output_shapes[key] = (None,) + value self.save(output_shapes, self.output_shape_file) -- GitLab