Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Address sample_sequence_length greater than min_length_time_axis #45

Open
wants to merge 2 commits into
base: main
Choose a base branch
from

Conversation

mickvangelderen
Copy link
Contributor

@mickvangelderen mickvangelderen commented Nov 9, 2024

Based on #43, please review that first and if it wasn't merged yet, please review this PR per commit.

The goal of this PR is to make the tests not generate flashbax warnings and to prevent code from being added that would trigger a flashbax warning.

I am not sure that the suggested changes in the last commit are right.

I am unsure why min_length_time_axis=min_length // add_batch_size + 1 is used in the trajectory buffer source code, particularly the + 1. I am also unsure why the min_length fixture is defined as:

@pytest.fixture
def min_length(sample_batch_size: int) -> int:
    return int(sample_batch_size + 1)

I could understand if it was defined in terms of add_batch_size, but it is not.

@SimonDuToit
Copy link
Contributor

SimonDuToit commented Dec 10, 2024

I think the + 1 in min_length_time_axis=min_length // add_batch_size + 1 is just to be conservative. This matches how max_length_time_axis is calculated.
The min_length fixture is defined to be viable, since it has to be at least as big as sample_size.

Copy link
Contributor

@SimonDuToit SimonDuToit left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure why its necessary to replace min_length with add_batch_size. Otherwise it makes sense, please just see my comment on the merge conflict.

buffer = flat_buffer.make_flat_buffer(
max_length, min_length, sample_batch_size, False, int(min_length + 10)
max_length, add_batch_size, sample_batch_size, False, add_batch_size
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What is the reasoning for this change?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

max_length_time_axis = 3, min_length_time_axis = 1, add_batch_size = 75, sample_sequence_length = 2, period = 1, max_size = None

    def validate_trajectory_buffer_args(
        max_length_time_axis: Optional[int],
        min_length_time_axis: int,
        add_batch_size: int,
        sample_sequence_length: int,
        period: int,
        max_size: Optional[int],
    ) -> None:
        """Validate the arguments of the trajectory buffer."""
    
        validate_size(max_length_time_axis, max_size, add_batch_size)
    
        if max_size is not None:
            max_length_time_axis = max_size // add_batch_size
    
        if sample_sequence_length > min_length_time_axis:
>           warnings.warn(
                "`sample_sequence_length` greater than `min_length_time_axis`, therefore "
                "overriding `min_length_time_axis`"
                "to be set to `sample_sequence_length`, as we need at least `sample_sequence_length` "
                "timesteps added to the buffer before we can sample.",
                stacklevel=1,
            )
E           UserWarning: `sample_sequence_length` greater than `min_length_time_axis`, therefore overriding `min_length_time_axis`to be set to `sample_sequence_length`, as we need at least `sample_sequence_length` timesteps added to the buffer before we can sample.

The warning checks that sample_sequence_length <= min_length_time_axis. The flat buffer creates a trajectory buffer with sample_sequence_length = 2 and min_length_time_axis = min_length // add_batch_size + 1 = 1 // (1 + 10) + 1 = 0 + 1 = 1. When calling create trajectory buffer you need to ensure that min_length >= add_batch_size hence the change. Perhaps it is not the right solution.

pyproject.toml Show resolved Hide resolved
@@ -92,7 +92,7 @@ def test_mixed_trajectory_sample(
for i in range(3):
buffer = trajectory_buffer.make_trajectory_buffer(
max_length_time_axis=200 * (i + 1),
min_length_time_axis=0,
min_length_time_axis=sample_sequence_length,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a sensible change.

@mickvangelderen mickvangelderen force-pushed the mick/fix-sample-seq-len-gt-min-len-time branch from c485ecf to 082261d Compare December 10, 2024 17:51
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants