Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

is_last should not be checked forValueFunction #56

Open
wants to merge 1 commit into
base: main
Choose a base branch
from

Conversation

Josh00-Lu
Copy link

@Josh00-Lu Josh00-Lu commented Mar 16, 2024

Thanks for your wonderful project.

I think, here, is_last should not be checked for ValueFunction, since there's no nn.Identity() here for the last block:

        for ind, (dim_in, dim_out) in enumerate(in_out):
            is_last = ind >= (num_resolutions - 1)

            self.blocks.append(nn.ModuleList([
                ResidualTemporalBlock(dim_in, dim_out, kernel_size=5, embed_dim=time_dim, horizon=horizon),
                ResidualTemporalBlock(dim_out, dim_out, kernel_size=5, embed_dim=time_dim, horizon=horizon),
                Downsample1d(dim_out)
            ]))

            if not is_last: # Should delete this line
                horizon = horizon // 2

Otherwise, the dimension may not be correctly handled:

e.g. For horizon=32 and dim_mults=(1, 4, 8), a dimension-mismatch error would appear:

RuntimeError: mat1 and mat2 shapes cannot be multiplied (1x96 and 160x64)
        for ind, (dim_in, dim_out) in enumerate(in_out):
            is_last = ind >= (num_resolutions - 1)

            self.blocks.append(nn.ModuleList([
                ResidualTemporalBlock(dim_in, dim_out, kernel_size=5, embed_dim=time_dim, horizon=horizon),
                ResidualTemporalBlock(dim_out, dim_out, kernel_size=5, embed_dim=time_dim, horizon=horizon),
                Downsample1d(dim_out)
            ]))

            horizon = horizon // 2

This bug will not appear when horizon is small (0 // 2 =0), but will appear for large horizon.

@Josh00-Lu Josh00-Lu changed the title fix bugs forValueFunction is_last should not be checked forValueFunction Mar 16, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

1 participant