Skip to content
Snippets Groups Projects
Commit bc388f6c authored by paul.tresson_ird.fr's avatar paul.tresson_ird.fr
Browse files

have pre-selected backbone options rather than typing timm arch names. resolves #7

parent 371c83da
No related branches found
No related tags found
No related merge requests found
......@@ -69,6 +69,7 @@ class EncoderAlgorithm(QgsProcessingAlgorithm):
BATCH_SIZE = 'BATCH_SIZE'
CUDA_ID = 'CUDA_ID'
BACKBONE_CHOICE = 'BACKBONE_CHOICE'
BACKBONE_OPT = 'BACKBONE_OPT'
MERGE_METHOD = 'MERGE_METHOD'
WORKERS = 'WORKERS'
PAUSES = 'PAUSES'
......@@ -205,13 +206,36 @@ class EncoderAlgorithm(QgsProcessingAlgorithm):
defaultValue=True
)
)
self.backbone_opt = [
'ViT base DINO',
'ViT tiny Imagenet (smallest)',
'ViT base MAE',
'SAM',
'--Empty--'
]
self.timm_backbone_opt = [
'vit_base_patch16_224.dino',
'vit_tiny_patch16_224.augreg_in21k',
'vit_base_patch16_224.mae',
'samvit_base_patch16.sa1b',
]
self.addParameter (
QgsProcessingParameterEnum(
name = self.BACKBONE_OPT,
description = self.tr(
"Pre-selected backbones if you don't know what to pick"),
defaultValue = 0,
options = self.backbone_opt,
)
)
self.addParameter (
QgsProcessingParameterString(
name = self.BACKBONE_CHOICE,
description = self.tr(
'Backbone choice (see huggingface.co/timm/)'),
defaultValue = 'vit_base_patch16_224.dino',
# defaultValue = 'vit_small_patch16_224.dino',
'Enter a architecture name if you want to test another backbone (see huggingface.co/timm/)'),
defaultValue = None,
optional=True,
)
)
......@@ -571,9 +595,19 @@ class EncoderAlgorithm(QgsProcessingAlgorithm):
ckpt_path = self.parameterAsFile(
parameters, self.CKPT, context)
self.backbone_name = self.parameterAsString(
## Use the given backbone name is any, use preselected models otherwise.
input_name = self.parameterAsString(
parameters, self.BACKBONE_CHOICE, context)
if input_name:
self.backbone_name = input_name
else:
backbone_idx = self.parameterAsEnum(
parameters, self.BACKBONE_OPT, context)
self.backbone_name = self.timm_backbone_opt[backbone_idx]
feedback.pushInfo(f'self.backbone_name:{self.backbone_name}')
self.stride = self.parameterAsInt(
parameters, self.STRIDE, context)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment