From 57664072cbf49da8f593ff089bd273e06b46339c Mon Sep 17 00:00:00 2001
From: ptresson <paul.tresson@ird.fr>
Date: Fri, 25 Oct 2024 14:08:27 +0200
Subject: [PATCH] move device definition to fix quantization

---
 .gitignore    |  1 +
 encoder.py    | 22 +++++++++++-----------
 utils/trch.py |  1 -
 3 files changed, 12 insertions(+), 12 deletions(-)

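Note (commentary, not part of the commit message): torch's dynamic quantization of nn.Linear layers to qint8 is a CPU-side optimization, so quantize_model(model, device) needs to know the target device before it is called. Before this patch the device-resolution block sat near the end of processAlgorithm-style setup, after the quantization attempt, so `device` was presumably still undefined inside the try/except and the code silently fell back to the unquantized model. A minimal sketch of the ordering this patch establishes (names follow encoder.py, but the body is illustrative and the call inside the try block is assumed to be utils/trch.py's quantize_model, not the exact plugin code):

    import torch

    # 1. Resolve the device first, clamping an out-of-range CUDA id.
    if torch.cuda.is_available() and self.use_gpu:
        if self.cuda_id + 1 > torch.cuda.device_count():
            self.cuda_id = torch.cuda.device_count() - 1
        device = f'cuda:{self.cuda_id}'
    else:
        self.batch_size = 1
        device = 'cpu'
    feedback.pushInfo(f'Device id: {device}')

    # 2. Only then decide whether to quantize; dynamic qint8 quantization
    #    of nn.Linear runs on CPU backends, so the device must be known here.
    if self.quantization:
        try:
            model = quantize_model(model, device)
        except Exception:
            feedback.pushInfo('quantization impossible, using original model.')

    # 3. Finally move the (possibly still unquantized) model to the device.
    model.to(device=device)
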
diff --git a/.gitignore b/.gitignore
index 92f7a9f..49c4d22 100644
--- a/.gitignore
+++ b/.gitignore
@@ -14,3 +14,4 @@ pyrightconfig.json
 models
 dinov2_base_518px_p14_reg4.lvd142m.pc22_subtrain2.8M_12n8A10040h115999it_over_125000_timm_compatible.pth
 docs/build
+symbology-style.db
diff --git a/encoder.py b/encoder.py
index 4744a54..6a1d269 100644
--- a/encoder.py
+++ b/encoder.py
@@ -426,6 +426,17 @@ class EncoderAlgorithm(IAMAPAlgorithm):
         data_config = timm.data.resolve_model_data_config(model)
         _, h, w, = data_config['input_size']
 
+        if torch.cuda.is_available() and self.use_gpu:
+            if self.cuda_id + 1 > torch.cuda.device_count():
+                self.cuda_id = torch.cuda.device_count() - 1
+            cuda_device = f'cuda:{self.cuda_id}'
+            device = f'cuda:{self.cuda_id}'
+        else:
+            self.batch_size = 1
+            device = 'cpu'
+
+        feedback.pushInfo(f'Device id: {device}')
+
         if self.quantization:
 
             try :
@@ -438,7 +449,6 @@ class EncoderAlgorithm(IAMAPAlgorithm):
 
                 feedback.pushInfo(f'quantization impossible, using original model.')
 
-
         transform = AugmentationSequential(
                 T.ConvertImageDtype(torch.float32), # change dtype for normalize to be possible
                 K.Normalize(self.means,self.sds), # normalize occurs only on raster, not mask
@@ -467,16 +477,6 @@ class EncoderAlgorithm(IAMAPAlgorithm):
             self.load_feature = False
             feedback.pushWarning(f'\n !!!No available patch sample inside the chosen extent!!! \n')
 
-        if torch.cuda.is_available() and self.use_gpu:
-            if self.cuda_id + 1 > torch.cuda.device_count():
-                self.cuda_id = torch.cuda.device_count() - 1
-            cuda_device = f'cuda:{self.cuda_id}'
-            device = f'cuda:{self.cuda_id}'
-        else:
-            self.batch_size = 1
-            device = 'cpu'
-
-        feedback.pushInfo(f'Device id: {device}')
 
         feedback.pushInfo(f'model to device')
         model.to(device=device)
diff --git a/utils/trch.py b/utils/trch.py
index c126c54..be7c158 100644
--- a/utils/trch.py
+++ b/utils/trch.py
@@ -21,4 +21,3 @@ def quantize_model(model, device):
             model, {nn.Linear}, dtype=torch.qint8
         )
     return model
-
-- 
GitLab
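
Note: only the tail of utils/trch.py's quantize_model appears in the hunk above. A plausible full shape, for context (the `device == 'cpu'` guard is an assumption inferred from this patch's intent, not confirmed by the diff):

    import torch
    from torch import nn

    def quantize_model(model, device):
        # Dynamic qint8 quantization of nn.Linear layers targets CPU execution,
        # so it is presumably skipped when a CUDA device was selected.
        if device == 'cpu':
            model = torch.quantization.quantize_dynamic(
                model, {nn.Linear}, dtype=torch.qint8
            )
        return model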