From 57664072cbf49da8f593ff089bd273e06b46339c Mon Sep 17 00:00:00 2001 From: ptresson <paul.tresson@ird.fr> Date: Fri, 25 Oct 2024 14:08:27 +0200 Subject: [PATCH] move device definition to fix quantization --- .gitignore | 1 + encoder.py | 22 +++++++++++----------- utils/trch.py | 1 - 3 files changed, 12 insertions(+), 12 deletions(-) diff --git a/.gitignore b/.gitignore index 92f7a9f..49c4d22 100644 --- a/.gitignore +++ b/.gitignore @@ -14,3 +14,4 @@ pyrightconfig.json models dinov2_base_518px_p14_reg4.lvd142m.pc22_subtrain2.8M_12n8A10040h115999it_over_125000_timm_compatible.pth docs/build +symbology-style.db diff --git a/encoder.py b/encoder.py index 4744a54..6a1d269 100644 --- a/encoder.py +++ b/encoder.py @@ -426,6 +426,17 @@ class EncoderAlgorithm(IAMAPAlgorithm): data_config = timm.data.resolve_model_data_config(model) _, h, w, = data_config['input_size'] + if torch.cuda.is_available() and self.use_gpu: + if self.cuda_id + 1 > torch.cuda.device_count(): + self.cuda_id = torch.cuda.device_count() - 1 + cuda_device = f'cuda:{self.cuda_id}' + device = f'cuda:{self.cuda_id}' + else: + self.batch_size = 1 + device = 'cpu' + + feedback.pushInfo(f'Device id: {device}') + if self.quantization: try : @@ -438,7 +449,6 @@ class EncoderAlgorithm(IAMAPAlgorithm): feedback.pushInfo(f'quantization impossible, using original model.') - transform = AugmentationSequential( T.ConvertImageDtype(torch.float32), # change dtype for normalize to be possible K.Normalize(self.means,self.sds), # normalize occurs only on raster, not mask @@ -467,16 +477,6 @@ class EncoderAlgorithm(IAMAPAlgorithm): self.load_feature = False feedback.pushWarning(f'\n !!!No available patch sample inside the chosen extent!!! \n') - if torch.cuda.is_available() and self.use_gpu: - if self.cuda_id + 1 > torch.cuda.device_count(): - self.cuda_id = torch.cuda.device_count() - 1 - cuda_device = f'cuda:{self.cuda_id}' - device = f'cuda:{self.cuda_id}' - else: - self.batch_size = 1 - device = 'cpu' - - feedback.pushInfo(f'Device id: {device}') feedback.pushInfo(f'model to dedvice') model.to(device=device) diff --git a/utils/trch.py b/utils/trch.py index c126c54..be7c158 100644 --- a/utils/trch.py +++ b/utils/trch.py @@ -21,4 +21,3 @@ def quantize_model(model, device): model, {nn.Linear}, dtype=torch.qint8 ) return model - -- GitLab