-
Notifications
You must be signed in to change notification settings - Fork 1
/
gen_v1_6.py
49 lines (37 loc) · 1.92 KB
/
gen_v1_6.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
"""TemporalGAN Generator, Version 1 family (this file: Version 1.6).

Version 1 uses a dual-stream encoder that fuses Sentinel2_T1 with Sentinel1_T2.
The base Version 1 used no attention module; attention was added in later versions:
## Version 1.2: Initial_down performs no downsampling
## Version 1.3: Channel Attention at the bottleneck
## Version 1.5 (discarded): GLAM at 8x8 in the down-sampling streams
## Version 1.6: GLAM moved from 8x8 to 16x16
"""
import torch
import torch.nn as nn
from submodules.gen_cnn_block import Block
from submodules.cbam import ChannelAttention
from submodules.glam.glam import GLAM
from gen_v1_3 import Generator as Generator_v1_3
class Generator(Generator_v1_3):
    """Generator v1.6: the v1.3 generator with GLAM after the 4th down stage.

    The inherited ``down4_s2`` and ``down4_s1`` stages of both encoder streams
    are each replaced by ``nn.Sequential(original_down4, GLAM)``. GLAM therefore
    runs on the output of the fourth down-sampling stage of each stream.

    Args:
        s2_in_channels: Channels of the Sentinel-2 input (forwarded to v1.3).
        s1_in_channels: Channels of the Sentinel-1 input (forwarded to v1.3).
        out_channels: Channels of the generated output (forwarded to v1.3).
        features: Base feature width (forwarded to v1.3). Each GLAM block is
            built for ``features * 8`` input channels.
        glam_reduced_channels: ``num_reduced_channels`` passed to each GLAM.
        glam_feature_map_size: ``feature_map_size`` passed to each GLAM. The
            default 16 matches the "GLAM at 16x16" placement of v1.6. This
            presumably corresponds to 256x256 inputs, the size used by
            ``test`` (TODO: confirm against Generator_v1_3).
        glam_kernel_size: ``kernel_size`` passed to each GLAM.
    """

    def __init__(
        self,
        s2_in_channels=3,
        s1_in_channels=1,
        out_channels=1,
        features=64,
        glam_reduced_channels=32,
        glam_feature_map_size=16,
        glam_kernel_size=5,
    ):
        super().__init__(s2_in_channels, s1_in_channels, out_channels, features)
        # NOTE(review): this assumes the inherited down4 stages output
        # features * 8 channels. TODO: confirm against Generator_v1_3.
        glam_kwargs = dict(
            in_channels=features * 8,
            num_reduced_channels=glam_reduced_channels,
            feature_map_size=glam_feature_map_size,
            kernel_size=glam_kernel_size,
        )
        self.glam4_s2 = GLAM(**glam_kwargs)
        self.glam4_s1 = GLAM(**glam_kwargs)
        # Each GLAM stays registered both as a named attribute and as a child
        # of the Sequential wrapper, exactly as before. This keeps the module
        # names (and therefore state_dict keys) unchanged.
        self.down4_s2 = nn.Sequential(self.down4_s2, self.glam4_s2)
        self.down4_s1 = nn.Sequential(self.down4_s1, self.glam4_s1)
def test(summary=False):
    """Smoke-test the generator on random inputs and print output statistics.

    Builds a ``Generator`` with a 7-channel Sentinel-2 input and a 1-channel
    Sentinel-1 input. It runs one forward pass on a random batch of one
    256x256 image per stream, then prints the output shape and the output's
    min, mean, max and dtype.

    Args:
        summary: If True, also print a ``torchinfo`` model summary. This
            requires the ``torchinfo`` package.

    Returns:
        The output tensor of the forward pass.
    """
    s2_channels, s1_channels, size = 7, 1, 256
    # Use the GPU if available, otherwise the CPU.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # Create the inputs directly with the target dtype and device.
    s2 = torch.rand((1, s2_channels, size, size), dtype=torch.float32, device=device)
    s1 = torch.rand((1, s1_channels, size, size), dtype=torch.float32, device=device)
    model = Generator(s2_in_channels=s2_channels, s1_in_channels=s1_channels, features=64).to(device)
    # Gradients are not needed for a shape/value smoke test.
    with torch.no_grad():
        preds = model(s2, s1)
    print(preds.shape)
    print(torch.min(preds), torch.mean(preds), torch.max(preds), preds.dtype)
    if summary:
        # Alias the import so it does not shadow the ``summary`` flag.
        from torchinfo import summary as model_summary

        # torchinfo runs its forward pass on ``device``. Pass the model's
        # actual device instead of hard-coding 'cpu'.
        model_summary(
            model,
            input_size=[(1, s2_channels, size, size), (1, s1_channels, size, size)],
            device=device,
            col_names=["input_size", "output_size", "num_params"],
            col_width=20,
            row_settings=["var_names"],
        )
    return preds
if __name__ == "__main__":
    # Script entry point: run the smoke test without the torchinfo summary
    # (summary=False is already the default of ``test``).
    test()