-
Notifications
You must be signed in to change notification settings - Fork 31
/
demo6-optim_roughness_textures.py
79 lines (63 loc) · 3.23 KB
/
demo6-optim_roughness_textures.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
import os
import argparse
import glob
import jittor as jt
from jittor import nn
import numpy as np
from skimage.io import imread, imsave
import tqdm
import imageio
import jrender as jr
# Fix NumPy's global RNG seed for reproducible runs.
np.random.seed(1)
# Directory containing this script; sample inputs/outputs live under ./data.
current_dir = os.path.split(os.path.realpath(__file__))[0]
data_dir = os.path.join(current_dir, 'data')
# Enable CUDA execution in Jittor.
jt.flags.use_cuda = 1
class Model(nn.Module):
    """Differentiable-rendering setup that fits per-face roughness textures.

    The mesh geometry, albedo textures and metallic textures are held fixed
    (``stop_grad``); ``roughness_textures`` is left trainable so an optimizer
    can fit it to a reference image.
    """

    def __init__(self, filename_obj, filename_ref, *, texture_size=4,
                 eye_angles=(2.732, 30, 140)):
        """Load the mesh and reference image and build the renderer.

        Args:
            filename_obj: path to the .obj mesh; loaded with its textures.
            filename_ref: path to the reference image the render is fit to.
            texture_size: per-face texture resolution; each face gets
                ``texture_size * texture_size`` texels in every texture.
            eye_angles: arguments forwarded to
                ``renderer.transform.set_eyes_from_angles`` before each
                render (presumably distance, elevation, azimuth — confirm
                against jrender).
        """
        super(Model, self).__init__()
        self.eye_angles = tuple(eye_angles)

        # set template mesh
        self.template_mesh = jr.Mesh.from_obj(
            filename_obj, texture_res=texture_size, load_texture=True,
            dr_type='softras')
        # Geometry and albedo are fixed: stop_grad keeps them out of the
        # gradient computation.
        self.vertices = self.template_mesh.vertices.stop_grad()
        self.faces = self.template_mesh.faces.stop_grad()
        self.textures = self.template_mesh.textures.stop_grad()

        # Per-face material textures. faces.shape[1] is used as the face
        # count, which assumes faces is batched as (1, F, 3) — TODO confirm.
        texture_shape = (1, self.faces.shape[1], texture_size * texture_size, 1)
        # Constant metallic value of 0.4, not optimized.
        self.metallic_textures = (
            jt.zeros(texture_shape).float32() + 0.4).stop_grad()
        # Roughness starts at 1.0 and is the optimized variable, so it is
        # deliberately NOT stop_grad'ed.
        self.roughness_textures = jt.ones(texture_shape).float32()

        # load reference image: scaled by 1/255 (assumes 8-bit input) and
        # reordered HWC -> (1, C, H, W) to match the renderer output.
        self.image_ref = jt.array(
            imread(filename_ref).astype('float32') / 255.
        ).permute(2, 0, 1).unsqueeze(0).stop_grad()

        # setup renderer
        self.renderer = jr.Renderer(dr_type='softras')

    def execute(self):
        """Render the mesh and return the summed squared error vs. the reference."""
        self.renderer.transform.set_eyes_from_angles(*self.eye_angles)
        image = self.renderer(
            self.vertices, self.faces, self.textures,
            metallic_textures=self.metallic_textures,
            roughness_textures=self.roughness_textures)
        loss = jt.sum((image - self.image_ref).sqr())
        return loss
def make_gif(filename, frame_pattern='./tmp/_tmp_*.png'):
    """Assemble the frames matching ``frame_pattern`` into an animation.

    Frames are appended in sorted (lexicographic) path order, and each frame
    file is deleted right after it has been added.

    Args:
        filename: output path for the animation (e.g. ``result.gif``).
        frame_pattern: glob pattern selecting the frame images; the default
            matches the ``./tmp/_tmp_NNNN.png`` frames written by ``main``.
    """
    with imageio.get_writer(filename, mode='I') as writer:
        # Use a separate name so the output path ``filename`` is not shadowed.
        for frame_path in sorted(glob.glob(frame_pattern)):
            writer.append_data(imageio.imread(frame_path))
            os.remove(frame_path)
    # Leaving the ``with`` block closes the writer; no explicit close() needed.
def main():
    """Fit the roughness textures of a mesh to a reference image.

    Runs 15 Adam steps on ``model.roughness_textures``. After each step the
    current render is saved as a PNG frame under ./tmp, and the frames are
    finally combined into ``<filename_output>/result.gif``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('-io', '--filename_obj', type=str,
                        default=os.path.join(data_dir, 'obj/spot/spot_triangulated.obj'))
    parser.add_argument('-ir', '--filename_ref', type=str,
                        default=os.path.join(data_dir, 'ref/ref_roughness.png'))
    parser.add_argument('-or', '--filename_output', type=str,
                        default=os.path.join(data_dir, 'results/output_optim_roughness_textures'))
    # NOTE(review): --gpu is parsed but never used; CUDA is enabled globally
    # via jt.flags.use_cuda at module level.
    parser.add_argument('-g', '--gpu', type=int, default=0)
    args = parser.parse_args()

    os.makedirs(args.filename_output, exist_ok=True)
    # Frames are staged in ./tmp (where make_gif reads them from); create it
    # so imsave does not fail with FileNotFoundError on a fresh checkout.
    os.makedirs('./tmp', exist_ok=True)

    model = Model(args.filename_obj, args.filename_ref)
    # Only the roughness textures are handed to the optimizer.
    optimizer = nn.Adam([model.roughness_textures], lr=0.1, betas=(0.5, 0.999))
    loop = tqdm.tqdm(range(15))
    for num in loop:
        loop.set_description('Optimizing')
        loss = model()
        optimizer.step(loss)

        # Re-render the current state from the same viewpoint Model uses.
        model.renderer.transform.set_eyes_from_angles(2.732, 30, 140)
        images = model.renderer(model.vertices, model.faces, model.textures,
                                metallic_textures=model.metallic_textures,
                                roughness_textures=model.roughness_textures)
        image = images.numpy()[0].transpose((1, 2, 0))
        # Convert float [0, 1] data to uint8 explicitly (clipping any stray
        # values) instead of relying on skimage's implicit float conversion,
        # which warns and rejects out-of-range floats.
        image = (np.clip(image, 0.0, 1.0) * 255).round().astype(np.uint8)
        imsave('./tmp/_tmp_%04d.png' % num, image)

    make_gif(os.path.join(args.filename_output, 'result.gif'))


if __name__ == '__main__':
    main()