From d1cd2d88cec95f8804afd4f67253a87fa76b7357 Mon Sep 17 00:00:00 2001 From: Benjamin Lefaudeux Date: Fri, 8 Nov 2024 17:02:30 +0000 Subject: [PATCH] Should be good to go for now, missing some more tests --- .github/workflows/gopy.yml | 4 +- .gitignore | 1 + cmd/main.go | 16 +- pkg/architecture.go | 1 + pkg/client.go | 68 +++++-- pkg/generator_db.go | 35 +++- pkg/generator_filesystem.go | 14 +- pkg/serdes.go | 2 +- pkg/worker_http.go | 2 +- python/benchmark_db.py | 18 +- python/benchmark_filesystem.py | 21 +- python/test_datago_db.py | 186 ++++++++++++++++++ python/test_datago_filesystem.py | 28 +++ .../datago_test.cpython-311-pytest-8.3.3.pyc | Bin 4019 -> 0 bytes .../test_datago.cpython-311-pytest-8.3.3.pyc | Bin 4152 -> 0 bytes python/tests/test_datago.py | 54 ----- requirements-tests.txt | 7 + requirements.txt | 1 - tests/client_test.go | 14 +- 19 files changed, 353 insertions(+), 119 deletions(-) create mode 100644 python/test_datago_db.py create mode 100644 python/test_datago_filesystem.py delete mode 100644 python/tests/__pycache__/datago_test.cpython-311-pytest-8.3.3.pyc delete mode 100644 python/tests/__pycache__/test_datago.cpython-311-pytest-8.3.3.pyc delete mode 100644 python/tests/test_datago.py create mode 100644 requirements-tests.txt diff --git a/.github/workflows/gopy.yml b/.github/workflows/gopy.yml index a23a360..43fa4c8 100644 --- a/.github/workflows/gopy.yml +++ b/.github/workflows/gopy.yml @@ -64,5 +64,5 @@ jobs: run: | ls - python3 -m pip install -r requirements.txt - pytest -xv python/tests/* + python3 -m pip install -r requirements-tests.txt + pytest -xv python/* diff --git a/.gitignore b/.gitignore index 70d0b2d..878658d 100644 --- a/.gitignore +++ b/.gitignore @@ -30,3 +30,4 @@ go.work.sum build __pycache__ +*.pyc diff --git a/cmd/main.go b/cmd/main.go index e77b7aa..45ac8f7 100644 --- a/cmd/main.go +++ b/cmd/main.go @@ -12,11 +12,13 @@ import ( func main() { // Define flags - config := datago.DatagoConfig{} - config.SetDefaults() + config := datago.GetDatagoConfig() - sourceConfig := datago.GeneratorFileSystemConfig{RootPath: os.Getenv("DATAROOM_TEST_FILESYSTEM")} + sourceConfig := datago.SourceFileSystemConfig{RootPath: os.Getenv("DATAROOM_TEST_FILESYSTEM")} sourceConfig.PageSize = 10 + sourceConfig.Rank = 0 + sourceConfig.WorldSize = 1 + config.ImageConfig = datago.ImageTransformConfig{ DefaultImageSize: 1024, DownsamplingRatio: 32, @@ -26,8 +28,8 @@ func main() { config.Concurrency = *flag.Int("concurrency", 64, "The number of concurrent http requests to make") config.PrefetchBufferSize = *flag.Int("item_fetch_buffer", 256, "The number of items to pre-load") config.SamplesBufferSize = *flag.Int("item_ready_buffer", 128, "The number of items ready to be served") + config.Limit = *flag.Int("limit", 2000, "The number of items to fetch") - limit := flag.Int("limit", 2000, "The number of items to fetch") profile := flag.Bool("profile", false, "Whether to profile the code") // Parse the flags and instantiate the client @@ -65,10 +67,10 @@ func main() { // Fetch all of the binary payloads as they become available // NOTE: This is useless, just making sure that we empty the payloads channel - for i := 0; i < *limit; i++ { + for { sample := dataroom_client.GetSample() if sample.ID == "" { - fmt.Println("No more samples ", i, " samples served") + fmt.Println("No more samples") break } } @@ -78,7 +80,7 @@ func main() { // Calculate the elapsed time elapsedTime := time.Since(startTime) - fps := float64(*limit) / elapsedTime.Seconds() + fps := float64(config.Limit) / elapsedTime.Seconds() fmt.Printf("Total execution time: %.2f \n", elapsedTime.Seconds()) fmt.Printf("Average throughput: %.2f samples per second\n", fps) } diff --git a/pkg/architecture.go b/pkg/architecture.go index 6a17c70..f9b86fa 100644 --- a/pkg/architecture.go +++ b/pkg/architecture.go @@ -23,6 +23,7 @@ type Sample struct { ID string Source string Attributes map[string]interface{} + DuplicateState int Image ImagePayload Masks map[string]ImagePayload AdditionalImages map[string]ImagePayload diff --git a/pkg/client.go b/pkg/client.go index 6d215f4..4a714be 100644 --- a/pkg/client.go +++ b/pkg/client.go @@ -30,6 +30,7 @@ type DataSourceConfig struct { PageSize int `json:"page_size"` Rank int `json:"rank"` WorldSize int `json:"world_size"` + Limit int `json:"limit"` } type ImageTransformConfig struct { @@ -41,7 +42,7 @@ type ImageTransformConfig struct { PreEncodeImages bool `json:"pre_encode_images"` } -func (c *ImageTransformConfig) SetDefaults() { +func (c *ImageTransformConfig) setDefaults() { c.DefaultImageSize = 512 c.DownsamplingRatio = 16 c.MinAspectRatio = 0.5 @@ -57,17 +58,25 @@ type DatagoConfig struct { PrefetchBufferSize int `json:"prefetch_buffer_size"` SamplesBufferSize int `json:"samples_buffer_size"` Concurrency int `json:"concurrency"` + Limit int `json:"limit"` } -func (c *DatagoConfig) SetDefaults() { - dbConfig := GeneratorDBConfig{} - dbConfig.SetDefaults() +func (c *DatagoConfig) setDefaults() { + dbConfig := SourceDBConfig{} + dbConfig.setDefaults() c.SourceConfig = dbConfig - c.ImageConfig.SetDefaults() + c.ImageConfig.setDefaults() c.PrefetchBufferSize = 64 c.SamplesBufferSize = 32 c.Concurrency = 64 + c.Limit = 0 +} + +func GetDatagoConfig() DatagoConfig { + config := DatagoConfig{} + config.setDefaults() + return config } func DatagoConfigFromJSON(jsonString string) DatagoConfig { @@ -80,21 +89,30 @@ func DatagoConfigFromJSON(jsonString string) DatagoConfig { sourceConfig, err := json.Marshal(tempConfig["source_config"]) if err != nil { + fmt.Println("Error marshalling source_config", tempConfig["source_config"], err) log.Panicf("Error marshalling source_config: %v", err) } + // Unmarshal the source config based on the source type + // NOTE: The undefined fields will follow the default values switch tempConfig["source_type"] { case string(SourceTypeDB): - var dbConfig GeneratorDBConfig + dbConfig := SourceDBConfig{} + dbConfig.setDefaults() + err = json.Unmarshal(sourceConfig, &dbConfig) if err != nil { + fmt.Println("Error unmarshalling DB config", sourceConfig, err) log.Panicf("Error unmarshalling DB config: %v", err) } config.SourceConfig = dbConfig case string(SourceTypeFileSystem): - var fsConfig GeneratorFileSystemConfig + fsConfig := SourceFileSystemConfig{} + fsConfig.setDefaults() + err = json.Unmarshal(sourceConfig, &fsConfig) if err != nil { + fmt.Println("Error unmarshalling Filesystem config", sourceConfig, err) log.Panicf("Error unmarshalling FileSystem config: %v", err) } config.SourceConfig = fsConfig @@ -127,7 +145,9 @@ type DatagoClient struct { waitGroup *sync.WaitGroup cancel context.CancelFunc - ImageConfig ImageTransformConfig + imageConfig ImageTransformConfig + servedSamples int + limit int // Flexible generator, backend and dispatch goroutines generator Generator @@ -157,14 +177,14 @@ func GetClient(config DatagoConfig) *DatagoClient { fmt.Println(reflect.TypeOf(config.SourceConfig)) switch config.SourceConfig.(type) { - case GeneratorDBConfig: + case SourceDBConfig: fmt.Println("Creating a DB-backed dataloader") - dbConfig := config.SourceConfig.(GeneratorDBConfig) + dbConfig := config.SourceConfig.(SourceDBConfig) generator = newDatagoGeneratorDB(dbConfig) backend = BackendHTTP{config: &dbConfig, concurrency: config.Concurrency} - case GeneratorFileSystemConfig: + case SourceFileSystemConfig: fmt.Println("Creating a FileSystem-backed dataloader") - fsConfig := config.SourceConfig.(GeneratorFileSystemConfig) + fsConfig := config.SourceConfig.(SourceFileSystemConfig) generator = newDatagoGeneratorFileSystem(fsConfig) backend = BackendFileSystem{config: &config, concurrency: config.Concurrency} default: @@ -177,7 +197,9 @@ func GetClient(config DatagoConfig) *DatagoClient { chanPages: make(chan Pages, 2), chanSampleMetadata: make(chan SampleDataPointers, config.PrefetchBufferSize), chanSamples: make(chan Sample, config.SamplesBufferSize), - ImageConfig: config.ImageConfig, + imageConfig: config.ImageConfig, + servedSamples: 0, + limit: config.Limit, context: nil, cancel: nil, waitGroup: nil, @@ -218,13 +240,13 @@ func (c *DatagoClient) Start() { // Optionally crop and resize the images and masks on the fly var arAwareTransform *ARAwareTransform = nil - if c.ImageConfig.CropAndResize { + if c.imageConfig.CropAndResize { fmt.Println("Cropping and resizing images") - fmt.Println("Base image size | downsampling ratio | min | max:", c.ImageConfig.DefaultImageSize, c.ImageConfig.DownsamplingRatio, c.ImageConfig.MinAspectRatio, c.ImageConfig.MaxAspectRatio) - arAwareTransform = newARAwareTransform(c.ImageConfig) + fmt.Println("Base image size | downsampling ratio | min | max:", c.imageConfig.DefaultImageSize, c.imageConfig.DownsamplingRatio, c.imageConfig.MinAspectRatio, c.imageConfig.MaxAspectRatio) + arAwareTransform = newARAwareTransform(c.imageConfig) } - if c.ImageConfig.PreEncodeImages { + if c.imageConfig.PreEncodeImages { fmt.Println("Pre-encoding images, we'll return serialized JPG and PNG bytes") } @@ -247,7 +269,7 @@ func (c *DatagoClient) Start() { wg.Add(1) go func() { defer wg.Done() - c.backend.collectSamples(c.chanSampleMetadata, c.chanSamples, arAwareTransform, c.ImageConfig.PreEncodeImages) // Fetch the payloads and and deserialize them + c.backend.collectSamples(c.chanSampleMetadata, c.chanSamples, arAwareTransform, c.imageConfig.PreEncodeImages) // Fetch the payloads and and deserialize them }() c.waitGroup = &wg @@ -255,15 +277,23 @@ func (c *DatagoClient) Start() { // Get a deserialized sample from the client func (c *DatagoClient) GetSample() Sample { - if c.cancel == nil { + if c.cancel == nil && c.servedSamples == 0 { fmt.Println("Dataroom client not started. Starting it on the first sample, this adds some initial latency") fmt.Println("Please consider starting the client in anticipation by calling .Start()") c.Start() } + if c.limit > 0 && c.servedSamples == c.limit { + fmt.Println("Reached the limit of samples to serve, stopping the client") + c.Stop() + return Sample{} + } + if sample, ok := <-c.chanSamples; ok { + c.servedSamples++ return sample } + fmt.Println("chanSamples closed, no more samples to serve") return Sample{} } diff --git a/pkg/generator_db.go b/pkg/generator_db.go index c6b0310..73150a3 100644 --- a/pkg/generator_db.go +++ b/pkg/generator_db.go @@ -29,6 +29,7 @@ type urlLatent struct { type dbSampleMetadata struct { Id string `json:"id"` Attributes map[string]interface{} `json:"attributes"` + DuplicateState int `json:"duplicate_state"` ImageDirectURL string `json:"image_direct_url"` Latents []urlLatent `json:"latents"` Tags []string `json:"tags"` @@ -73,7 +74,7 @@ type dbRequest struct { } // -- Define the front end goroutine --------------------------------------------------------------------------------------------------------------------------------------------------------------- -type GeneratorDBConfig struct { +type SourceDBConfig struct { DataSourceConfig Sources string `json:"sources"` RequireImages bool `json:"require_images"` @@ -86,7 +87,9 @@ type GeneratorDBConfig struct { LacksMasks string `json:"lacks_masks"` HasLatents string `json:"has_latents"` LacksLatents string `json:"lacks_latents"` - ReturnLatents string `json:"return_latents"` + + ReturnLatents string `json:"return_latents"` + ReturnDuplicateState bool `json:"return_duplicate_state"` MinShortEdge int `json:"min_short_edge"` MaxShortEdge int `json:"max_short_edge"` @@ -95,7 +98,7 @@ type GeneratorDBConfig struct { RandomSampling bool `json:"random_sampling"` } -func (c *GeneratorDBConfig) SetDefaults() { +func (c *SourceDBConfig) setDefaults() { c.PageSize = 512 c.Rank = -1 c.WorldSize = -1 @@ -112,15 +115,17 @@ func (c *GeneratorDBConfig) SetDefaults() { c.HasLatents = "" c.LacksLatents = "" c.ReturnLatents = "" + c.ReturnDuplicateState = false c.MinShortEdge = -1 c.MaxShortEdge = -1 c.MinPixelCount = -1 c.MaxPixelCount = -1 c.RandomSampling = false + } -func (c *GeneratorDBConfig) getDbRequest() dbRequest { +func (c *SourceDBConfig) getDbRequest() dbRequest { fields := "attributes,image_direct_url" if len(c.HasLatents) > 0 || len(c.HasMasks) > 0 { @@ -142,6 +147,11 @@ func (c *GeneratorDBConfig) getDbRequest() dbRequest { fmt.Println("Including embeddings") } + if c.ReturnDuplicateState { + fields += ",duplicate_state" + fmt.Println("Including duplicate state") + } + // Report some config data fmt.Println("Rank | World size:", c.Rank, c.WorldSize) fmt.Println("Sources:", c.Sources, "| Fields:", fields) @@ -153,6 +163,13 @@ func (c *GeneratorDBConfig) getDbRequest() dbRequest { return fmt.Sprintf("%d", val) } + // Align rank and worldsize with the partitioning + if c.WorldSize < 2 { + // No partitioning + c.WorldSize = -1 + c.Rank = -1 + } + return dbRequest{ fields: fields, sources: c.Sources, @@ -176,12 +193,18 @@ func (c *GeneratorDBConfig) getDbRequest() dbRequest { } } +func GetSourceDBConfig() SourceDBConfig { + config := SourceDBConfig{} + config.setDefaults() + return config +} + type datagoGeneratorDB struct { baseRequest http.Request - config GeneratorDBConfig + config SourceDBConfig } -func newDatagoGeneratorDB(config GeneratorDBConfig) datagoGeneratorDB { +func newDatagoGeneratorDB(config SourceDBConfig) datagoGeneratorDB { request := config.getDbRequest() api_key := os.Getenv("DATAROOM_API_KEY") diff --git a/pkg/generator_filesystem.go b/pkg/generator_filesystem.go index 5d2aec7..ad0dfba 100644 --- a/pkg/generator_filesystem.go +++ b/pkg/generator_filesystem.go @@ -18,12 +18,12 @@ type fsSampleMetadata struct { } // -- Define the front end goroutine --------------------------------------------------------------------------------------------------------------------------------------------------------------- -type GeneratorFileSystemConfig struct { +type SourceFileSystemConfig struct { DataSourceConfig RootPath string `json:"root_path"` } -func (c *GeneratorFileSystemConfig) SetDefaults() { +func (c *SourceFileSystemConfig) setDefaults() { c.PageSize = 512 c.Rank = 0 c.WorldSize = 1 @@ -31,12 +31,18 @@ func (c *GeneratorFileSystemConfig) SetDefaults() { c.RootPath = os.Getenv("DATAROOM_TEST_FILESYSTEM") } +func GetSourceFileSystemConfig() SourceFileSystemConfig { + config := SourceFileSystemConfig{} + config.setDefaults() + return config +} + type datagoGeneratorFileSystem struct { extensions set - config GeneratorFileSystemConfig + config SourceFileSystemConfig } -func newDatagoGeneratorFileSystem(config GeneratorFileSystemConfig) datagoGeneratorFileSystem { +func newDatagoGeneratorFileSystem(config SourceFileSystemConfig) datagoGeneratorFileSystem { supported_img_extensions := []string{".jpg", ".jpeg", ".png", ".JPEG", ".JPG", ".PNG"} var extensionsMap = make(set) for _, ext := range supported_img_extensions { diff --git a/pkg/serdes.go b/pkg/serdes.go index 57b535a..417b29a 100644 --- a/pkg/serdes.go +++ b/pkg/serdes.go @@ -185,7 +185,7 @@ func fetchImage(client *http.Client, url string, retries int, transform *ARAware return nil, -1., err_report } -func fetchSample(config *GeneratorDBConfig, http_client *http.Client, sample_result dbSampleMetadata, transform *ARAwareTransform, pre_encode_image bool) *Sample { +func fetchSample(config *SourceDBConfig, http_client *http.Client, sample_result dbSampleMetadata, transform *ARAwareTransform, pre_encode_image bool) *Sample { // Per sample work: // - fetch the raw payloads // - deserialize / decode, depending on the types diff --git a/pkg/worker_http.go b/pkg/worker_http.go index e3588b3..bd8f668 100644 --- a/pkg/worker_http.go +++ b/pkg/worker_http.go @@ -7,7 +7,7 @@ import ( ) type BackendHTTP struct { - config *GeneratorDBConfig + config *SourceDBConfig concurrency int } diff --git a/python/benchmark_db.py b/python/benchmark_db.py index a648256..26b5601 100644 --- a/python/benchmark_db.py +++ b/python/benchmark_db.py @@ -1,10 +1,9 @@ from datago import datago # type: ignore import time -import typer from tqdm import tqdm import numpy as np from go_types import go_array_to_pil_image, go_array_to_numpy - +import typer def benchmark( source: str = typer.Option("SOURCE", help="The source to test out"), @@ -18,21 +17,24 @@ def benchmark( test_latents: bool = typer.Option(True, help="Test latents"), ): print(f"Running benchmark for {source} - {limit} samples") - client_config = datago.DatagoConfig() - client_config.SetDefaults() + + # Get a generic client config + client_config = datago.GetDatagoConfig() client_config.ImageConfig.CropAndResize = crop_and_resize - source_config = datago.GeneratorDBConfig() - source_config.SetDefaults() + # Specify the source parameters as you see fit + source_config = datago.GetSourceDBConfig() source_config.Sources = source source_config.RequireImages = require_images source_config.RequireEmbeddings = require_embeddings source_config.HasMasks = "segmentation_mask" if test_masks else "" source_config.HasLatents = "caption_latent_t5xxl" if test_latents else "" - client_config.SourceConfig = source_config + # Get a new client instance, happy benchmarking + client_config.SourceConfig = source_config client = datago.GetClient(client_config) - client.Start() + + client.Start() # Optional, but good practice to start the client to reduce latency to first sample (while you're instantiating models for instance) start = time.time() # Make sure in the following that we compare apples to apples, meaning in that case diff --git a/python/benchmark_filesystem.py b/python/benchmark_filesystem.py index 6f26b17..266e416 100644 --- a/python/benchmark_filesystem.py +++ b/python/benchmark_filesystem.py @@ -1,11 +1,10 @@ from datago import datago # type: ignore import time -import typer from tqdm import tqdm from go_types import go_array_to_pil_image import os import json - +import typer def benchmark( root_path: str = typer.Option( @@ -15,7 +14,7 @@ def benchmark( crop_and_resize: bool = typer.Option( True, help="Crop and resize the images on the fly" ), - concurrency: int = typer.Option(32, help="The number of concurrent workers"), + concurrency: int = typer.Option(64, help="The number of coroutines"), compare_torch: bool = typer.Option(True, help="Compare against torch dataloader"), ): print(f"Running benchmark for {root_path} - {limit} samples") @@ -38,11 +37,13 @@ def benchmark( "prefetch_buffer_size": 64, "samples_buffer_size": 128, "concurrency": concurrency, + "limit": limit, } client = datago.GetClientFromJSON(json.dumps(client_config)) - client.Start() + start = time.time() + client.Start() # Make sure in the following that we compare apples to apples, meaning in that case # that we materialize the payloads in the python scope in the expected format @@ -53,6 +54,10 @@ def benchmark( if sample.ID and hasattr(sample, "Image"): img = go_array_to_pil_image(sample.Image) + if sample.ID is None: + print("No more samples") + break + fps = limit / (time.time() - start) print(f"Datago FPS {fps:.2f}") client.Stop() @@ -71,7 +76,7 @@ def benchmark( transform = ( transforms.Compose( [ - transforms.Resize((512, 512)), + transforms.Resize((512, 512), interpolation=transforms.InterpolationMode.BICUBIC), ] ) if crop_and_resize @@ -79,11 +84,11 @@ def benchmark( ) # Create the ImageFolder dataset - dataset = datasets.ImageFolder(root=root_path, transform=transform) + dataset = datasets.ImageFolder(root=root_path, transform=transform, allow_empty=True) - # Create a DataLoader to load the images in batches + # Create a DataLoader to allow for multiple workers dataloader = DataLoader( - dataset, batch_size=32, shuffle=False, num_workers=8, collate_fn=lambda x: x + dataset, batch_size=1, shuffle=False, num_workers=8, collate_fn=lambda x: x ) # Iterate over the DataLoader diff --git a/python/test_datago_db.py b/python/test_datago_db.py new file mode 100644 index 0000000..369e215 --- /dev/null +++ b/python/test_datago_db.py @@ -0,0 +1,186 @@ +from datago import datago +import pytest +import os +import json +from go_types import go_array_to_pil_image, go_array_to_numpy + + +def get_test_source() -> str: + test_source = os.getenv("DATAROOM_TEST_SOURCE", "COYO") + assert test_source is not None, "Please set DATAROOM_TEST_SOURCE" + return test_source + + +def get_json_config(): + client_config = { + "source_type": datago.SourceTypeDB, + "source_config": { + "page_size": 10, + "sources": get_test_source(), + "require_images": True, + "has_masks": "segmentation_mask", + "has_latents": "masked_image", + "has_attributes": "caption_coca,caption_cogvlm,caption_moondream", + "return_latents": "masked_image", + "rank": 0, + "world_size": 1, + }, + "image_config": { + "crop_and_resize": False, + "default_image_size": 512, + "downsampling_ratio": 16, + "min_aspect_ratio": 0.5, + "max_aspect_ratio": 2.0, + "pre_encode_images": False, + }, + "prefetch_buffer_size": 64, + "samples_buffer_size": 128, + "concurrency": 1, + "limit": 10, + } + return client_config + + +def get_dataset(client_config: str): + client = datago.GetClientFromJSON(json.dumps(client_config)) + + class Dataset: + def __init__(self, client): + self.client = client + + def __iter__(self): + return self + + def __next__(self): + new_sample = self.client.GetSample() + if new_sample.ID == "": + raise StopIteration + return new_sample + + return Dataset(client) + + +def test_get_sample_db(): + # Check that we can instantiate a client and get a sample, nothing more + client_config = datago.GetDatagoConfig() + client_config.SamplesBufferSize = 10 + + source_config = datago.GetSourceDBConfig() + source_config.Sources = get_test_source() + client_config.SourceConfig = source_config + + client = datago.GetClient(client_config) + data = client.GetSample() + assert data.ID != "" + + +N_SAMPLES = 3 + + +def test_caption_and_image(): + client_config = get_json_config() + dataset = get_dataset(client_config) + + def check_image(img, channels=3): + assert img.Height > 0 + assert img.Width > 0 + + assert img.Height <= img.OriginalHeight + assert img.Width <= img.OriginalWidth + assert img.Channels == channels + + for i, sample in enumerate(dataset): + assert sample.Source != "" + assert sample.ID != "" + + assert len(sample.Attributes["caption_coca"]) != len( + sample.Attributes["caption_cogvlm"] + ), "Caption lengths should not be equal" + + check_image(sample.Image, 3) + check_image(sample.AdditionalImages["masked_image"], 3) + check_image(sample.Masks["segmentation_mask"], 1) + + # Check the image decoding + assert go_array_to_pil_image(sample.Image).mode == "RGB", "Image should be RGB" + assert ( + go_array_to_pil_image(sample.AdditionalImages["masked_image"]).mode == "RGB" + ), "Image should be RGB" + assert ( + go_array_to_pil_image(sample.Masks["segmentation_mask"]).mode == "L" + ), "Mask should be L" + + if i > N_SAMPLES: + break + + +def test_image_resize(): + client_config = get_json_config() + client_config["image_config"]["crop_and_resize"] = True + dataset = get_dataset(client_config) + + for i, sample in enumerate(dataset): + # Assert that all the images in the sample have the same size + assert ( + sample.Image.Height + == sample.AdditionalImages["masked_image"].Height + == sample.Masks["segmentation_mask"].Height + and sample.Image.Height > 0 + ) + assert ( + sample.Image.Width + == sample.AdditionalImages["masked_image"].Width + == sample.Masks["segmentation_mask"].Width + and sample.Image.Width > 0 + ) + if i > N_SAMPLES: + break + + +def test_has_tags(): + client_config = get_json_config() + client_config["source_config"]["tags"] = "v4_trainset_hq" + + dataset = get_dataset(client_config) + sample = next(iter(dataset)) + + assert "v4_trainset_hq" in sample.Tags, "v4_trainset_hq should be in the tags" + + +def no_test_jpg_compression(): + # Check that the images are compressed as expected + client_config = get_json_config() + client_config["image_config"]["pre_encode_images"] = True + dataset = get_dataset(client_config) + + sample = next(iter(dataset)) + + assert go_array_to_pil_image(sample.Image).mode == "RGB", "Image should be RGB" + assert ( + go_array_to_pil_image(sample.AdditionalImages["masked_image"]).mode == "RGB" + ), "Image should be RGB" + assert ( + go_array_to_pil_image(sample.Masks["segmentation_mask"]).mode == "L" + ), "Mask should be L" + + # Check the embeddings decoding + assert ( + go_array_to_numpy(sample.CocaEmbedding) is not None + ), "Embedding should be set" + + +def test_duplicate_state(): + client_config = get_json_config() + client_config["source_config"]["return_duplicate_state"] = True + dataset = get_dataset(client_config) + + sample = next(iter(dataset)) + assert sample.DuplicateState in [ + 0, + 1, + 2, + ], "Duplicate state should be 0, 1 or 2" + + +if __name__ == "__main__": + pytest.main(["-v", __file__]) diff --git a/python/test_datago_filesystem.py b/python/test_datago_filesystem.py new file mode 100644 index 0000000..8d7ba6a --- /dev/null +++ b/python/test_datago_filesystem.py @@ -0,0 +1,28 @@ +import os +from PIL import Image +from datago import datago + + +# FIXME: Would need to generate more fake data to test this +def no_test_get_sample_filesystem(): + cwd = os.getcwd() + + try: + # Dump a sample image to the filesystem + img = Image.new("RGB", (100, 100)) + img.save(cwd + "/test.png") + + # Check that we can instantiate a client and get a sample, nothing more + client_config = datago.GetDatagoConfig() + client_config.SourceType = "filesystem" + client_config.SamplesBufferSize = 1 + + source_config = datago.SourceFileSystemConfig() + source_config.RootPath = cwd + source_config.PageSize = 1 + + client = datago.GetClient(client_config, source_config) + data = client.GetSample() + assert data.ID != "" + finally: + os.remove(cwd + "/test.png") diff --git a/python/tests/__pycache__/datago_test.cpython-311-pytest-8.3.3.pyc b/python/tests/__pycache__/datago_test.cpython-311-pytest-8.3.3.pyc deleted file mode 100644 index ba915fd034c4b88d76c319511a9895a37dec568a..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 4019 zcma)8&2QVt6(4?xq9n_XY$uT&CzkxxD0b{@c9XO~w`eww1Ebqy5oZteg-}e&kyjta zkp5f>Fp536h1>d4Yz{s6ut<^SFDQC1dXx+=>E1VwE%-u@99-x8CUQi&`|yq3$#k}Q$3x)_mg zL@sG%ZBgTGsT5s|a+JQPL#>oz<@jP8#?)`fV#167?J_l>Nvqq8ekG$G+Ak9;y3*6q zNle{3vn)T82zddYUqFJiz`tDV4R+nnNlY{IPJX!xt;l>izichFz`Rm`;`jCnkZ*}2 zL=mzg+NKm%vp{n4ebzsF`_b)(4<7v5cyxE+k+JaL%ZGRF=42*UZKf?-j#XKIPTr!H z_cv#rRLj;($tvX6Olxg(=3dd6@s~BTy5&5nR%RT_cI;N4fz0&k7V7~qoH6X`8eOtz z7f6f3XZHcwA+Hk2y`R<3PBs!#USg`QPEq_>p$#IMZ_tEOJZ*!|0Q3Srzto6J4)~b| zJwYADNyI_vtr8w{$x3iLhynM;FQ()YLkqajVGBtv2_JS%a-{=3nA2W|HP&QTUWozV&5WA5s}&{AMV)N0 zQbIzKQbE5P4ND4hzXfaGlxhlmk(z1}Fpb##H+Tyfc`m(|v!6bfwy%s|UfsH$v*Ta4 z8ktWs*t(Xpr{`zmwM^FMZ{%#IqFFpwm;%E{t|PhezB_N*7IpZvTL6R3Te)6}|CV*l zas$6ptrUvOEU{oYvuJ1~$L99Jx)!+cw8k~xlvrvZUtTR)_GfE_f<+gK+ZK!6vz$An zqE&Ht(+4nlewM|531$jLY2+zO8B6(6$)MIMU8^VDKpFl;J8!J{TcPXt>)@Kx|w48SgYjd@fujCyV z$>~g6@^_~#@%2%(f0jVC{=t0I&V16&eC#v6)SKKJi^dj$u@ogMru2!f{k z44n82;mElo6^;!tL2)4jZvzuHZE)|xMS^7Ew!^Lu-R^;K-ho7{!#E7?u51WL9+EBG zc6d%rafz!KJOM$8;Lc@3tjUZ3hXIk7kUZh0JKZXT1oc2T=U^_@9hxCrd`LCGX==5o ztGUruFwtEdf=Qd}ToLZljJi4m6L7I^^jI*7IpZI6uAmL`64uRwNxSU+0Ov-+e+Ys} zJ5Gd)P8hSGzykZ{Uf#n3Yv&fos$oDH(6{Dy2ux#WGOJiPZ=PJUy89{LH8?HEtVrdKP= zEM6!=pxCk61d^X1xr`(S zglW_&SJ$mvf+?^xrWDJ|6w_ap^h2ZJ#~M0`Gx*-7u!TRDP9vE?at(-$#ux+;sQnl+ z*cn4U93Q<66Yw_J|AK>r1UBBam;2r1-sH}mBegFw{<^zwulO|bbi2`g#p}MZt2Fh* zpY-4B`+fUgH&P=Febm!O>-y*q8tF}Ysj)_C!b?r;YE8ZSr9OD54>t56Pamr5L;oJS zysLToz^n8nFMXqtzU8HF?IxSUBfCieG%ro!=!NF!7}lBBsqkM|shornPmVqrqBKeq!f zbami?eXYm$V*8^{6OG={hCb%$WA)&D_XqjJp?B}z+31D;e0OSYko4s6x z7b}LbEni*FrI;QhNBnBC0mDzHrv2mvQk-dOZP20vsgdYn(N>>-v3aVe`SX0+Qofk~ zoEsKT^LU!V?-y#vG|MoUimN~#ub^R23^Y9O@|*fz|9QD;u9d7`(9dB2P22uEkRwHs zq$b|alb<>n=y;nXxx=3(iS6*GNzT^8?~x`+SC5DwdQ2KPB7&sPN#jREkefv9+TA`> k03t~vO|@rd?weoj_wR4k&tw{E)>E_f;Kk*Ra2Vg(k6cw>T>t<8 diff --git a/python/tests/__pycache__/test_datago.cpython-311-pytest-8.3.3.pyc b/python/tests/__pycache__/test_datago.cpython-311-pytest-8.3.3.pyc deleted file mode 100644 index 56f2a6a79005c10bbb040de61fc6bc485cb86545..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 4152 zcma)8&2QVt6(4?xq9n_XV!11??bx!DP3X_vbd%jKy4@ym90YBgAWjeUg~~HQxLV8s89|=vNsY-e=4Knd-vY62mK)ZAaXx8Y~lV9>^hvs==Bv*S{ zHi<49=T`XV93d~^3kqo97TBkxUSrq1!ugV}+1iQ;t;F4mwqh)`z^oF0;`e48$T!4} zR1vZowRJ9TW`QXDyq}-F@#x0G`}gmuk8UqKQWx(3;^D2^3h(o#X-oWp%T-F7C_OFDEooGOK~m`z+OD-MbOup001&*V_r zUU$q+aD>%VD6fe+UrRWmllYRYu_if^^;f$ul*18;=BNvfw9M%uqL=ySu=rz7NwtJd;CiLzU*k#k$qVisU;O+9w=;g} z%G&l-#Y%tGYK(k7f~^k~Yx?eNx;9b__-l&gi+F=x2)+PfBv+AKd)KX5mO*Vc>=S@N z=O@Z}ivQN{nq^MnmRVgct@xP*!=A;y&)qSq2Gwkn z&fW}$u$$>87uYqm{L~%8zEv(6Rhu;f4L*N&)=&QwE>#=Xtru>t(CQ^W`jgY`Pn5lFLfzcYBumJAu9z^ivGpV;ADbL`>37yxSjc7 zzy!9gw=UYUBd~TL z0oZ{$?i4uE9spJ&haY#E%&|vqzg=q_hpC0bt5FEh6Tn4|K4yufJE$qHp(Z%Q5!7Rc zp8yWSF)J|*k%M=5aLMsqfB^9^jdYrPcN4s{%l*eQF+ZF00TKa zhvYmG{E-y@5=D~~lMSVx;DNl32r+bMKXruA`(1w^aCW9kz7Hd z0P!ViRLl)S$@l_H(icjV6^e;Coj@{)gslz#E1gC%gXBXb9|5uOI)?ED>LA1nRm~u2 zZi_OmJ_~nsVm% z@^9q*-u&u6Q3tr(!FOvHjT8WnY(>^TWL6$2VNrWf=8xC0;A>c9hg zD5v*Q`(sZt?)foSzUavp8{zx*xBRK1w{PEC=#2kxduncg{Bxjx?qk81Rkfm(s;auf zPi-hUUk>vsb}IdTHOR!KgA501)K_BZ=dtjz~o0F0TKGMC8H`xJySQBG-x7wYzg50EFX4 gn_^FW?(3iJ=l8c7=SEzy=!wNf_~LX&8H_FMzr96$p8x;= diff --git a/python/tests/test_datago.py b/python/tests/test_datago.py deleted file mode 100644 index 07543d6..0000000 --- a/python/tests/test_datago.py +++ /dev/null @@ -1,54 +0,0 @@ -from datago import datago -import pytest -import os -from PIL import Image - - -def get_test_source(): - return os.getenv("DATAROOM_TEST_SOURCE") - - -def test_get_sample_db(): - # Check that we can instantiate a client and get a sample, nothing more - client_config = datago.DatagoConfig() - client_config.SetDefaults() - client_config.SamplesBufferSize = 10 - - source_config = datago.GeneratorDBConfig() - source_config.SetDefaults() - source_config.Sources = get_test_source() - - client = datago.GetClient(client_config) - data = client.GetSample() - assert data.ID != "" - - -def no_test_get_sample_filesystem(): - cwd = os.getcwd() - - try: - # Dump a sample image to the filesystem - img = Image.new("RGB", (100, 100)) - img.save(cwd + "/test.png") - - # Check that we can instantiate a client and get a sample, nothing more - client_config = datago.DatagoConfig() - client_config.SetDefaults() - client_config.SourceType = "filesystem" - client_config.SamplesBufferSize = 1 - - source_config = datago.GeneratorFileSystemConfig() - source_config.RootPath = cwd - source_config.PageSize = 1 - - client = datago.GetClient(client_config, source_config) - data = client.GetSample() - assert data.ID != "" - finally: - os.remove(cwd + "/test.png") - - -# TODO: Backport all the image correctness tests - -if __name__ == "__main__": - pytest.main(["-v", __file__]) diff --git a/requirements-tests.txt b/requirements-tests.txt new file mode 100644 index 0000000..77f7cbd --- /dev/null +++ b/requirements-tests.txt @@ -0,0 +1,7 @@ +# Get core deps. +-r requirements.txt + +pytest +typer +tqdm +numpy diff --git a/requirements.txt b/requirements.txt index 33527cc..3868fb1 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,2 +1 @@ -pytest pillow diff --git a/tests/client_test.go b/tests/client_test.go index 5e38cc1..aa5498a 100644 --- a/tests/client_test.go +++ b/tests/client_test.go @@ -15,15 +15,13 @@ func get_test_source() string { } func get_default_test_config() datago.DatagoConfig { - config := datago.DatagoConfig{} - config.SetDefaults() + config := datago.GetDatagoConfig() - db_config := datago.GeneratorDBConfig{} - db_config.SetDefaults() + db_config := datago.GetSourceDBConfig() db_config.Sources = get_test_source() db_config.PageSize = 32 - config.SourceConfig = db_config + config.SourceConfig = db_config return config } @@ -104,7 +102,7 @@ func TestExtraFields(t *testing.T) { clientConfig := get_default_test_config() clientConfig.SamplesBufferSize = 1 - dbConfig := clientConfig.SourceConfig.(datago.GeneratorDBConfig) + dbConfig := clientConfig.SourceConfig.(datago.SourceDBConfig) dbConfig.HasLatents = "masked_image" dbConfig.HasMasks = "segmentation_mask" clientConfig.SourceConfig = dbConfig @@ -174,7 +172,7 @@ func TestImageBufferCompression(t *testing.T) { clientConfig.SamplesBufferSize = 1 clientConfig.ImageConfig.PreEncodeImages = true - dbConfig := clientConfig.SourceConfig.(datago.GeneratorDBConfig) + dbConfig := clientConfig.SourceConfig.(datago.SourceDBConfig) dbConfig.HasLatents = "masked_image" dbConfig.HasMasks = "segmentation_mask" clientConfig.SourceConfig = dbConfig @@ -250,7 +248,7 @@ func TestRanks(t *testing.T) { clientConfig := get_default_test_config() clientConfig.SamplesBufferSize = 1 - dbConfig := clientConfig.SourceConfig.(datago.GeneratorDBConfig) + dbConfig := clientConfig.SourceConfig.(datago.SourceDBConfig) dbConfig.WorldSize = 2 dbConfig.Rank = 0 clientConfig.SourceConfig = dbConfig