diff --git a/tests/test_encoder.py b/tests/test_encoder.py
index 0469f4b5f21ab1e1c3dc1dd5fce36c4b50e4bad0..b0036b6faa09b39a6a5f53024fe926cc2403abbd 100644
--- a/tests/test_encoder.py
+++ b/tests/test_encoder.py
@@ -6,6 +6,10 @@ from qgis.core import (
         QgsProcessingFeedback,
         )
 
+import timm
+import torch
+from torchgeo.transforms import AugmentationSequential
+
 from ..encoder import EncoderAlgorithm
 
 ## for hashing without using to much memory
@@ -35,8 +39,43 @@ class TestEncoderAlgorithm(unittest.TestCase):
         assert result_file_hash == '018b6fc5d88014a7e515824d95ca8686'
         os.remove(expected_result_path)
 
+    def test_timm_create_model(self):
+        """Each candidate backbone must build via timm with 6 input
+        channels and yield the expected forward_features shape."""
+
+        archs = [
+                'vit_base_patch16_224.dino',
+                'vit_tiny_patch16_224.augreg_in21k',
+                'vit_base_patch16_224.mae',
+                # 'samvit_base_patch16.sa1b',
+                ]
+        expected_output_size = [
+                torch.Size([1, 197, 768]),
+                torch.Size([1, 197, 192]),
+                torch.Size([1, 197, 768]),
+                # torch.Size([1, 256, 64, 64]),
+                ]
+
+        for arch, exp_feat_size in zip(archs, expected_output_size):
+            # num_classes=0 strips the classification head so only the
+            # feature tokens are returned; in_chans=6 checks that timm
+            # adapts the patch embedding to non-RGB input.
+            model = timm.create_model(
+                arch,
+                pretrained=True,
+                in_chans=6,
+                num_classes=0,
+                )
+            model = model.eval()
+
+            data_config = timm.data.resolve_model_data_config(model)
+            _, h, w = data_config['input_size']
+            output = model.forward_features(torch.randn(1, 6, h, w))
+
+            assert output.shape == exp_feat_size
+
 if __name__ == "__main__":
 
     test_encoder = TestEncoderAlgorithm()
     test_encoder.setUp()
     test_encoder.test_valid_parameters()
+    test_encoder.test_timm_create_model()