diff --git a/.gitignore b/.gitignore index 92f7a9f9f04c0190411f10637286c68a30cb438f..49c4d222bc0ef2b2d128f6b639589436f1d66716 100644 --- a/.gitignore +++ b/.gitignore @@ -14,3 +14,4 @@ pyrightconfig.json models dinov2_base_518px_p14_reg4.lvd142m.pc22_subtrain2.8M_12n8A10040h115999it_over_125000_timm_compatible.pth docs/build +symbology-style.db diff --git a/encoder.py b/encoder.py index 4744a54b6554059763a1ea9acae5eef5926bfab1..6a1d2698d1f9e35cf9f9677476934bb2420ac90b 100644 --- a/encoder.py +++ b/encoder.py @@ -426,6 +426,17 @@ class EncoderAlgorithm(IAMAPAlgorithm): data_config = timm.data.resolve_model_data_config(model) _, h, w, = data_config['input_size'] + if torch.cuda.is_available() and self.use_gpu: + if self.cuda_id + 1 > torch.cuda.device_count(): + self.cuda_id = torch.cuda.device_count() - 1 + cuda_device = f'cuda:{self.cuda_id}' + device = f'cuda:{self.cuda_id}' + else: + self.batch_size = 1 + device = 'cpu' + + feedback.pushInfo(f'Device id: {device}') + if self.quantization: try : @@ -438,7 +449,6 @@ class EncoderAlgorithm(IAMAPAlgorithm): feedback.pushInfo(f'quantization impossible, using original model.') - transform = AugmentationSequential( T.ConvertImageDtype(torch.float32), # change dtype for normalize to be possible K.Normalize(self.means,self.sds), # normalize occurs only on raster, not mask @@ -467,16 +477,6 @@ class EncoderAlgorithm(IAMAPAlgorithm): self.load_feature = False feedback.pushWarning(f'\n !!!No available patch sample inside the chosen extent!!! \n') - if torch.cuda.is_available() and self.use_gpu: - if self.cuda_id + 1 > torch.cuda.device_count(): - self.cuda_id = torch.cuda.device_count() - 1 - cuda_device = f'cuda:{self.cuda_id}' - device = f'cuda:{self.cuda_id}' - else: - self.batch_size = 1 - device = 'cpu' - - feedback.pushInfo(f'Device id: {device}') feedback.pushInfo(f'model to dedvice') model.to(device=device) diff --git a/utils/trch.py b/utils/trch.py index c126c5489b94113959885870ee59dd10ffe31db1..be7c158f6e18fd6df0e990d9a7bd670f994af12d 100644 --- a/utils/trch.py +++ b/utils/trch.py @@ -21,4 +21,3 @@ def quantize_model(model, device): model, {nn.Linear}, dtype=torch.qint8 ) return model -