If you do not know the root cause of the problem / bug, and wish someone to help you, please
post according to this template:
🐛 Bugs / Unexpected behaviors
camera_position_optimization_with_differentiable_rendering.py does not converge across multiple runs. I was trying to reproduce the output with the teapot.obj file over multiple runs; however, the code does not converge every time. Sometimes `self.camera_position` becomes `nan`. I also tried seeding everything, but this does not change the erratic behavior. The code below is borrowed from the PyTorch3D tutorial.

Instructions To Reproduce the Issue:
import os
import sys
import torch

if torch.__version__ == '1.6.0+cu101' and sys.platform.startswith('linux'):
    get_ipython().system('pip install pytorch3d')
else:
    need_pytorch3d = False
    try:
        import pytorch3d
    except ModuleNotFoundError:
        need_pytorch3d = True
    if need_pytorch3d:
        get_ipython().system('curl -LO https://github.com/NVIDIA/cub/archive/1.10.0.tar.gz')
        get_ipython().system('tar xzf 1.10.0.tar.gz')
        os.environ["CUB_HOME"] = os.getcwd() + "/cub-1.10.0"
        get_ipython().system("pip install 'git+https://github.com/facebookresearch/pytorch3d.git@stable'")
import os
import torch
import numpy as np
from tqdm.notebook import tqdm
import imageio
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
from skimage import img_as_ubyte
import random

# io utils
from pytorch3d.io import load_obj

# datastructures
from pytorch3d.structures import Meshes

# 3D transformations functions
from pytorch3d.transforms import Rotate, Translate

# rendering components
from pytorch3d.renderer import (
    FoVPerspectiveCameras, look_at_view_transform, look_at_rotation,
    RasterizationSettings, MeshRenderer, MeshRasterizer, BlendParams,
    SoftSilhouetteShader, HardPhongShader, PointLights, TexturesVertex,
)
import argparse

parser = argparse.ArgumentParser("Calculate ego motion of the camera")
parser.add_argument("--obj_file_path", type=str, default="/home/abhinav/Desktop/data_3d/teapot.obj", help='input obj file')
parser.add_argument("--seed", type=int, default=0, help='seed')
parser.add_argument("--lr", type=float, default=0.05)
parser.add_argument("--distance", type=float, default=3, help='distance')
parser.add_argument("--elevation", type=float, default=50, help='elevation')
parser.add_argument("--azimuth", type=float, default=0, help='azimuth')
parser.add_argument("--num_iter", type=int, default=200, help='num_iter')
parser.add_argument("--print_frequency", type=int, default=20)
args = parser.parse_args()
obj_file_path = args.obj_file_path
seed = args.seed
lr = args.lr

# The world coordinate system is defined as +Y up, +X left and +Z in. The teapot in world
# coordinates has the spout pointing to the left. We define a camera positioned on the
# positive z axis, hence it sees the spout to the right.
# Select the viewpoint using spherical angles.
distance = args.distance    # distance from camera to the object
elevation = args.elevation  # angle of elevation in degrees
azimuth = args.azimuth      # angle of azimuth/yaw. No rotation, so the camera is positioned on the +Z axis.
num_iter = args.num_iter
print_frequency = args.print_frequency
light_location = (2.0, 2.0, -2.0)
camera_init_position = np.array([3.0, 6.9, +2.5], dtype=np.float32)
show_current_target = False
obj_file_basename = os.path.basename(obj_file_path)
obj_file_basename_no_ext = obj_file_basename.split(".")[0]

# We will save images periodically and compose them into a GIF.
filename_output = os.path.join(os.getcwd(), obj_file_basename_no_ext + "_demo.gif")

# seed everything
torch.manual_seed(seed)
np.random.seed(seed)
random.seed(seed)
torch.cuda.manual_seed_all(seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
# torch.backends.cudnn.enabled = False
# torch.set_deterministic(True)

# Set the cuda device
if torch.cuda.is_available():
    device = torch.device("cuda:0")
    torch.cuda.set_device(device)
else:
    device = torch.device("cpu")
if obj_file_path.endswith(".obj"):
    # Load the obj and ignore the textures and materials.
    verts, faces_idx, _ = load_obj(obj_file_path)
    faces = faces_idx.verts_idx

print("Input file = {}".format(obj_file_path))
print("Number of vertices= {}".format(verts.shape[0]))
print("Number of faces = {}".format(faces.shape[0]))
print("min x= {: .2f}, max x= {: .2f}".format(torch.min(verts[:, 0]).item(), torch.max(verts[:, 0]).item()))
print("min y= {: .2f}, max y= {: .2f}".format(torch.min(verts[:, 1]).item(), torch.max(verts[:, 1]).item()))
print("min z= {: .2f}, max z= {: .2f}".format(torch.min(verts[:, 2]).item(), torch.max(verts[:, 2]).item()))
print("Distance = {}".format(distance))
print("Elevation= {}".format(elevation))
print("Azimuth = {}".format(azimuth))
print("Seed = {}".format(seed))
print("lr = {}".format(lr))
print("Num iter = {}".format(num_iter))
# Initialize each vertex to be white in color.
verts_rgb = torch.ones_like(verts)[None]  # (1, V, 3)
textures = TexturesVertex(verts_features=verts_rgb.to(device))

# Create a Meshes object for the teapot. Here we have only one mesh in the batch.
teapot_mesh = Meshes(
    verts=[verts.to(device)],
    faces=[faces.to(device)],
    textures=textures
)
# ## 2. Optimization setup
# ### Create a renderer
# A **renderer** in PyTorch3D is composed of a **rasterizer** and a **shader** which each have a
# number of subcomponents such as a **camera** (orthographic/perspective). Here we initialize
# some of these components and use default values for the rest.
# For optimizing the camera position we will use a renderer which produces a **silhouette** of
# the object only and does not apply any **lighting** or **shading**. We will also initialize
# another renderer which applies full **Phong shading** and use this for visualizing the outputs.

# Initialize a perspective camera.
cameras = FoVPerspectiveCameras(device=device)

# To blend the 100 faces we set a few parameters which control the opacity and the sharpness
# of edges. Refer to blending.py for more details.
blend_params = BlendParams(sigma=1e-4, gamma=1e-4)

# Define the settings for rasterization and shading. Here we set the output image to be of
# size 256x256. To form the blended image we use 100 faces for each pixel. We also set
# bin_size and max_faces_per_bin to None, which ensures that the faster coarse-to-fine
# rasterization method is used. Refer to rasterize_meshes.py for explanations of these
# parameters, and to docs/notes/renderer.md for an explanation of the difference between
# naive and coarse-to-fine rasterization.
raster_settings = RasterizationSettings(
    image_size=256,
    blur_radius=np.log(1. / 1e-4 - 1.) * blend_params.sigma,
    faces_per_pixel=100,
)
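# Note on blur_radius above: sigma * log(1/eps - 1) with eps = 1e-4 means that, in the
# sigmoid-based (SoftRas-style) blending, a face exactly blur_radius away from a pixel
# still contributes an opacity of ~eps -- my reading of blending.py, worth double-checking.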
# Create a silhouette mesh renderer by composing a rasterizer and a shader.
silhouette_renderer = MeshRenderer(
    rasterizer=MeshRasterizer(
        cameras=cameras,
        raster_settings=raster_settings
    ),
    shader=SoftSilhouetteShader(blend_params=blend_params)
)
# We will also create a Phong renderer. This is simpler and only needs to render one face per pixel.
raster_settings = RasterizationSettings(
    image_size=256,
    blur_radius=0.0,
    faces_per_pixel=1,
)
# We can add a point light in front of the object. Note that `(light_location,)` wraps the
# 3-tuple so that PointLights receives a batch of one location, i.e. shape (1, 3).
lights = PointLights(device=device, location=(light_location,))
phong_renderer = MeshRenderer(
    rasterizer=MeshRasterizer(
        cameras=cameras,
        raster_settings=raster_settings
    ),
    shader=HardPhongShader(device=device, cameras=cameras, lights=lights)
)
# ### Create a reference image
# We will first position the teapot and generate an image. We use helper functions to rotate
# the teapot to a desired viewpoint. Then we can use the renderers to produce an image. Here
# we will use both renderers and visualize the silhouette and full shaded image.

# Get the position of the camera based on the spherical angles.
R, T = look_at_view_transform(distance, elevation, azimuth, device=device)

# Render the teapot providing the values of R and T.
silhouette = silhouette_renderer(meshes_world=teapot_mesh, R=R, T=T)
image_ref = phong_renderer(meshes_world=teapot_mesh, R=R, T=T)

silhouette = silhouette.cpu().numpy()
image_ref = image_ref.cpu().numpy()
# plt.figure(figsize=(10, 10))
# plt.subplot(1, 2, 1)
# plt.imshow(silhouette.squeeze()[..., 3])  # only plot the alpha channel of the RGBA image
# plt.grid(False)
# plt.subplot(1, 2, 2)
# plt.imshow(image_ref.squeeze())
# plt.grid(False)

# ### Set up a basic model
class Model(nn.Module):
    def __init__(self, meshes, renderer, image_ref):
        super().__init__()
        self.meshes = meshes
        self.device = meshes.device
        self.renderer = renderer

        # Get the silhouette of the reference RGB image by finding all non-white pixel values.
        image_ref = torch.from_numpy((image_ref[..., :3].max(-1) != 1).astype(np.float32))
        self.register_buffer('image_ref', image_ref)

        # Create an optimizable parameter for the x, y, z position of the camera.
        self.camera_position = nn.Parameter(
            torch.from_numpy(camera_init_position).to(meshes.device))
        self.previous_camera_position = None

    def forward(self):
        # Render the image using the updated camera position. Based on the new position of
        # the camera we calculate the rotation and translation matrices.
        R = look_at_rotation(self.camera_position[None, :], device=self.device)  # (1, 3, 3)
        if torch.any(torch.isnan(R[0])):
            print(self.camera_position)
            print(R[0])
            sys.exit(0)
        # T = -R^T C maps the world-space camera position C to the view-space translation.
        T = -torch.bmm(R.transpose(1, 2), self.camera_position[None, :, None])[:, :, 0]  # (1, 3)
        image = self.renderer(meshes_world=self.meshes.clone(), R=R, T=T)

        # Calculate the silhouette loss.
        loss = torch.sum((image[..., 3] - self.image_ref) ** 2)
        return loss, image

# Initialize a model using the renderer, mesh and reference image.
model = Model(meshes=teapot_mesh, renderer=silhouette_renderer, image_ref=image_ref).to(device)
# Create an optimizer. Here we are using Adam and we pass in the parameters of the model.
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
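# (The only registered parameter is `camera_position`, a single 3-vector; the mesh is a
# plain attribute and the reference silhouette is a buffer, so Adam is optimizing just
# those three values.)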
# ### Visualize the starting position and the reference position
writer = imageio.get_writer(filename_output, mode='I', duration=0.3)

_, image_init = model()
plt.figure(figsize=(10, 10))
plt.subplot(1, 2, 1)
plt.imshow(image_init.detach().squeeze().cpu().numpy()[..., 3])
plt.grid(False)
plt.title("Starting position")
plt.axis("off")

plt.subplot(1, 2, 2)
plt.imshow(model.image_ref.cpu().numpy().squeeze())
plt.grid(False)
plt.title("Reference silhouette");
plt.axis("off")
if show_current_target:
    plt.show()
# ## 4. Run the optimization
# We run several iterations of the forward and backward pass and save outputs every
# `print_frequency` iterations. When this has finished, take a look at
# `./teapot_optimization_demo.gif` for a cool gif of the optimization process!
for i in range(num_iter):
    model.train()
    optimizer.zero_grad()
    loss, _ = model()
    loss.backward()
    optimizer.step()

    if loss.item() < 200:
        break

    # Save outputs to create a GIF.
    if (i + 1) % print_frequency == 0:
        model.eval()
        R = look_at_rotation(model.camera_position[None, :], device=model.device)
        T = -torch.bmm(R.transpose(1, 2), model.camera_position[None, :, None])[:, :, 0]  # (1, 3)
        image = phong_renderer(meshes_world=model.meshes.clone(), R=R, T=T)
        image = image[0, ..., :3].detach().squeeze().cpu().numpy()
        image = img_as_ubyte(image)
        writer.append_data(image)

        plt.figure()
        plt.imshow(image[..., :3])
        plt.title("iter: {:5d}, loss: {:.4f}".format(i + 1, loss.item()))
        plt.grid(False)
        plt.axis("off")
        print("Optimizing iter: {:5d}, loss {:.4f}".format(i + 1, loss.item()))

writer.close()
print("=> Saving to {}".format(filename_output))
Environment:
a. Ubuntu 18.04
b. Pytorch 1.7.1
c. Torchvision 0.8.2
d. Cuda 10.2
I ran into this issue recently as well. Should the camera_position_optimization_with_differentiable_rendering.ipynb demo be updated with the perspective_correct=False fix?
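For reference, a minimal sketch of what that fix would look like in the tutorial's silhouette renderer setup (assuming the current RasterizationSettings API, with all other values as in the tutorial above):

raster_settings = RasterizationSettings(
    image_size=256,
    blur_radius=np.log(1. / 1e-4 - 1.) * blend_params.sigma,
    faces_per_pixel=100,
    perspective_correct=False,  # avoid nan gradients from perspective-correct barycentrics
)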