From 929dae88414026ac9dc760cb01d23e7d99a4a7ce Mon Sep 17 00:00:00 2001
From: Narcon Nicolas <nicolas.narcon@inrae.fr>
Date: Thu, 21 Apr 2022 16:55:58 +0200
Subject: [PATCH 1/4] ENH: generate samples of same type as initial raster

---
 python/otbtf.py | 12 +++++++++---
 1 file changed, 9 insertions(+), 3 deletions(-)

diff --git a/python/otbtf.py b/python/otbtf.py
index b28a1cc4..a1cf9bd4 100644
--- a/python/otbtf.py
+++ b/python/otbtf.py
@@ -58,8 +58,11 @@ def read_as_np_arr(gdal_ds, as_patches=True):
         False, the shape is (1, psz_y, psz_x, nb_channels)
     :return: Numpy array of dim 4
     """
-    buffer = gdal_ds.ReadAsArray()
+    gdal_to_np_types = {1: 'uint8', 2: 'uint16', 3: 'int16', 4: 'uint32', 5: 'int32', 6: 'float32', 7: 'float64',
+                        10: 'complex64', 11: 'complex128'}
+    gdal_type = gdal_ds.GetRasterBand(1).DataType
     size_x = gdal_ds.RasterXSize
+    buffer = gdal_ds.ReadAsArray().astype(gdal_to_np_types[gdal_type])
     if len(buffer.shape) == 3:
         buffer = np.transpose(buffer, axes=(1, 2, 0))
     if not as_patches:
@@ -68,7 +71,7 @@ def read_as_np_arr(gdal_ds, as_patches=True):
     else:
         n_elems = int(gdal_ds.RasterYSize / size_x)
         size_y = size_x
-    return np.float32(buffer.reshape((n_elems, size_y, size_x, gdal_ds.RasterCount)))
+    return buffer.reshape((n_elems, size_y, size_x, gdal_ds.RasterCount))
 
 
 # -------------------------------------------------- Buffer class ------------------------------------------------------
@@ -244,8 +247,11 @@ class PatchesImagesReader(PatchesReaderBase):
 
     @staticmethod
     def _read_extract_as_np_arr(gdal_ds, offset):
+        gdal_to_np_types = {1: 'uint8', 2: 'uint16', 3: 'int16', 4: 'uint32', 5: 'int32', 6: 'float32', 7: 'float64',
+                            10: 'complex64', 11: 'complex128'}
         assert gdal_ds is not None
         psz = gdal_ds.RasterXSize
+        gdal_type = gdal_ds.GetRasterBand(1).DataType
         yoff = int(offset * psz)
         assert yoff + psz <= gdal_ds.RasterYSize
         buffer = gdal_ds.ReadAsArray(0, yoff, psz, psz)
@@ -254,7 +260,7 @@ class PatchesImagesReader(PatchesReaderBase):
         else:  # single-band raster
             buffer = np.expand_dims(buffer, axis=2)
 
-        return np.float32(buffer)
+        return buffer.astype(gdal_to_np_types[gdal_type])
 
     def get_sample(self, index):
         """
-- 
GitLab


From f1bda5b1336a6a8a1d63204175e11a885b635665 Mon Sep 17 00:00:00 2001
From: Remi Cresson <remi.cresson@irstea.fr>
Date: Fri, 22 Apr 2022 16:29:51 +0200
Subject: [PATCH 2/4] REFAC: gdal_to_np_type is global constant

---
 python/otbtf.py | 17 +++++++++++++----
 1 file changed, 13 insertions(+), 4 deletions(-)

diff --git a/python/otbtf.py b/python/otbtf.py
index a1cf9bd4..a58d10d2 100644
--- a/python/otbtf.py
+++ b/python/otbtf.py
@@ -34,6 +34,19 @@ import tensorflow as tf
 from osgeo import gdal
 from tqdm import tqdm
 
+# --------------------------------------------- GDAL to numpy types ----------------------------------------------------
+
+
+gdal_to_np_types = {1: 'uint8',
+                    2: 'uint16',
+                    3: 'int16',
+                    4: 'uint32',
+                    5: 'int32',
+                    6: 'float32',
+                    7: 'float64',
+                    10: 'complex64',
+                    11: 'complex128'}
+
 
 # ----------------------------------------------------- Helpers --------------------------------------------------------
 
@@ -58,8 +71,6 @@ def read_as_np_arr(gdal_ds, as_patches=True):
         False, the shape is (1, psz_y, psz_x, nb_channels)
     :return: Numpy array of dim 4
     """
-    gdal_to_np_types = {1: 'uint8', 2: 'uint16', 3: 'int16', 4: 'uint32', 5: 'int32', 6: 'float32', 7: 'float64',
-                        10: 'complex64', 11: 'complex128'}
     gdal_type = gdal_ds.GetRasterBand(1).DataType
     size_x = gdal_ds.RasterXSize
     buffer = gdal_ds.ReadAsArray().astype(gdal_to_np_types[gdal_type])
@@ -247,8 +258,6 @@ class PatchesImagesReader(PatchesReaderBase):
 
     @staticmethod
     def _read_extract_as_np_arr(gdal_ds, offset):
-        gdal_to_np_types = {1: 'uint8', 2: 'uint16', 3: 'int16', 4: 'uint32', 5: 'int32', 6: 'float32', 7: 'float64',
-                            10: 'complex64', 11: 'complex128'}
         assert gdal_ds is not None
         psz = gdal_ds.RasterXSize
         gdal_type = gdal_ds.GetRasterBand(1).DataType
-- 
GitLab


From ae46d5e19cc49714b941bd749f3d876a06ee0a0f Mon Sep 17 00:00:00 2001
From: Remi Cresson <remi.cresson@irstea.fr>
Date: Fri, 22 Apr 2022 16:35:31 +0200
Subject: [PATCH 3/4] REFAC: pylint

---
 python/otbtf.py | 16 ++++++++--------
 1 file changed, 8 insertions(+), 8 deletions(-)

diff --git a/python/otbtf.py b/python/otbtf.py
index a58d10d2..2083250a 100644
--- a/python/otbtf.py
+++ b/python/otbtf.py
@@ -37,7 +37,7 @@ from tqdm import tqdm
 # --------------------------------------------- GDAL to numpy types ----------------------------------------------------
 
 
-gdal_to_np_types = {1: 'uint8',
+GDAL_TO_NP_TYPES = {1: 'uint8',
                     2: 'uint16',
                     3: 'int16',
                     4: 'uint32',
@@ -73,7 +73,7 @@ def read_as_np_arr(gdal_ds, as_patches=True):
     """
     gdal_type = gdal_ds.GetRasterBand(1).DataType
     size_x = gdal_ds.RasterXSize
-    buffer = gdal_ds.ReadAsArray().astype(gdal_to_np_types[gdal_type])
+    buffer = gdal_ds.ReadAsArray().astype(GDAL_TO_NP_TYPES[gdal_type])
     if len(buffer.shape) == 3:
         buffer = np.transpose(buffer, axes=(1, 2, 0))
     if not as_patches:
@@ -269,7 +269,7 @@ class PatchesImagesReader(PatchesReaderBase):
         else:  # single-band raster
             buffer = np.expand_dims(buffer, axis=2)
 
-        return buffer.astype(gdal_to_np_types[gdal_type])
+        return buffer.astype(GDAL_TO_NP_TYPES[gdal_type])
 
     def get_sample(self, index):
         """
@@ -628,8 +628,8 @@ class TFRecords:
             """
             data_converted = {}
 
-            for k, d in data.items():
-                data_converted[k] = d.name
+            for key, value in data.items():
+                data_converted[key] = value.name
 
             return data_converted
 
@@ -644,7 +644,7 @@ class TFRecords:
 
             filepath = os.path.join(self.dirpath, f"{i}.records")
             with tf.io.TFRecordWriter(filepath) as writer:
-                for s in range(nb_sample):
+                for _ in range(nb_sample):
                     sample = dataset.read_one_sample()
                     serialized_sample = {name: tf.io.serialize_tensor(fea) for name, fea in sample.items()}
                     features = {name: self._bytes_feature(serialized_tensor) for name, serialized_tensor in
@@ -661,8 +661,8 @@ class TFRecords:
         :param filepath: Output file name
         """
 
-        with open(filepath, 'w') as f:
-            json.dump(data, f, indent=4)
+        with open(filepath, 'w') as file:
+            json.dump(data, file, indent=4)
 
     @staticmethod
     def load(filepath):
-- 
GitLab


From b1f064f0f11614ef8e40e9f168b50080d1ad2e14 Mon Sep 17 00:00:00 2001
From: Remi Cresson <remi.cresson@irstea.fr>
Date: Fri, 22 Apr 2022 16:39:27 +0200
Subject: [PATCH 4/4] REFAC: pylint

---
 python/otbtf.py | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/python/otbtf.py b/python/otbtf.py
index 2083250a..d7e1a0b0 100644
--- a/python/otbtf.py
+++ b/python/otbtf.py
@@ -670,8 +670,8 @@ class TFRecords:
         Return data from pickle format.
         :param filepath: Input file name
         """
-        with open(filepath, 'r') as f:
-            return json.load(f)
+        with open(filepath, 'r') as file:
+            return json.load(file)
 
     def convert_dataset_output_shapes(self, dataset):
         """
@@ -680,8 +680,8 @@ class TFRecords:
         """
         output_shapes = {}
 
-        for key in dataset.output_shapes.keys():
-            output_shapes[key] = (None,) + dataset.output_shapes[key]
+        for key, value in dataset.output_shapes.items():
+            output_shapes[key] = (None,) + value
 
         self.save(output_shapes, self.output_shape_file)
 
-- 
GitLab