Skip to content
Snippets Groups Projects
Commit 446b4f93 authored by paul.tresson_ird.fr's avatar paul.tresson_ird.fr
Browse files

give option to save as int8 (but not fully working yet)

parent c92b5146
No related branches found
No related tags found
No related merge requests found
......@@ -87,6 +87,7 @@ class EncoderAlgorithm(QgsProcessingAlgorithm):
REMOVE_TEMP_FILES = 'REMOVE_TEMP_FILES'
TEMP_FILES_CLEANUP_FREQ = 'TEMP_FILES_CLEANUP_FREQ'
JSON_PARAM = 'JSON_PARAM'
OUT_DTYPE = 'OUT_DTYPE'
def initAlgorithm(self, config=None):
......@@ -312,6 +313,15 @@ class EncoderAlgorithm(QgsProcessingAlgorithm):
defaultValue=0,
)
self.out_dtype_opt = ['float32', 'int8']
dtype_param = QgsProcessingParameterEnum(
name=self.OUT_DTYPE,
description=self.tr(
'Data type of exported features (int8 saves space)'),
options=self.out_dtype_opt,
defaultValue=0,
)
json_param = QgsProcessingParameterFile(
name=self.JSON_PARAM,
description=self.tr(
......@@ -325,6 +335,7 @@ class EncoderAlgorithm(QgsProcessingAlgorithm):
for param in (
crs_param,
res_param,
dtype_param,
chkpt_param,
cuda_id_param,
merge_param,
......@@ -535,7 +546,7 @@ class EncoderAlgorithm(QgsProcessingAlgorithm):
features = features.detach().cpu().numpy()
feedback.pushInfo(f'Features shape {features.shape}')
self.save_features(features,sample['bbox'], current)
self.save_features(features,sample['bbox'], current,dtype=self.out_dtype)
feedback.pushInfo(f'Features saved')
if current <= last_batch_done + 1:
......@@ -595,6 +606,7 @@ class EncoderAlgorithm(QgsProcessingAlgorithm):
tiles = all_tiles,
dst_path = dst_path,
method = self.merge_method,
dtype= self.out_dtype,
)
self.remove_temp_files()
self.all_encoding_done = True
......@@ -728,6 +740,7 @@ class EncoderAlgorithm(QgsProcessingAlgorithm):
feature: np.ndarray,
bboxes: BoundingBox,
nbatch: int,
dtype: str = 'float32'
):
# iterate over batch_size dimension
......@@ -743,7 +756,7 @@ class EncoderAlgorithm(QgsProcessingAlgorithm):
height=height,
width=width,
count=channels,
dtype='float32',
dtype=dtype,
crs=self.crs.toWkt(),
transform=rio_transform
) as ds:
......@@ -808,6 +821,10 @@ class EncoderAlgorithm(QgsProcessingAlgorithm):
self.backbone_name = self.timm_backbone_opt[backbone_idx]
feedback.pushInfo(f'self.backbone_name:{self.backbone_name}')
dtype_idx = self.parameterAsEnum(
parameters, self.OUT_DTYPE, context)
self.out_dtype = self.out_dtype_opt[dtype_idx]
self.stride = self.parameterAsInt(
parameters, self.STRIDE, context)
self.size = self.parameterAsInt(
......
......@@ -48,6 +48,7 @@ class TestEncoderAlgorithm(unittest.TestCase):
'TEMP_FILES_CLEANUP_FREQ': 1000,
'WORKERS': 0,
'JSON_PARAM': 'NULL',
'OUT_DTYPE': 0,
}
result = self.algorithm.processAlgorithm(parameters, self.context, self.feedback)
expected_result_path = os.path.join(self.algorithm.output_subdir,'merged.tif')
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment