diff --git a/.github/workflows/gopy.yml b/.github/workflows/gopy.yml
index a23a360..43fa4c8 100644
--- a/.github/workflows/gopy.yml
+++ b/.github/workflows/gopy.yml
@@ -64,5 +64,5 @@ jobs:
         run: |
           ls
-          python3 -m pip install -r requirements.txt
-          pytest -xv python/tests/*
+          python3 -m pip install -r requirements-tests.txt
+          pytest -xv python/*
diff --git a/.gitignore b/.gitignore
index 70d0b2d..878658d 100644
--- a/.gitignore
+++ b/.gitignore
@@ -30,3 +30,4 @@ go.work.sum
 build
 __pycache__
+*.pyc
diff --git a/cmd/main.go b/cmd/main.go
index e77b7aa..45ac8f7 100644
--- a/cmd/main.go
+++ b/cmd/main.go
@@ -12,11 +12,13 @@ import (
 func main() {
     // Define flags
-    config := datago.DatagoConfig{}
-    config.SetDefaults()
+    config := datago.GetDatagoConfig()
 
-    sourceConfig := datago.GeneratorFileSystemConfig{RootPath: os.Getenv("DATAROOM_TEST_FILESYSTEM")}
+    sourceConfig := datago.SourceFileSystemConfig{RootPath: os.Getenv("DATAROOM_TEST_FILESYSTEM")}
     sourceConfig.PageSize = 10
+    sourceConfig.Rank = 0
+    sourceConfig.WorldSize = 1
+
     config.ImageConfig = datago.ImageTransformConfig{
         DefaultImageSize:  1024,
         DownsamplingRatio: 32,
@@ -26,8 +28,8 @@ func main() {
     config.Concurrency = *flag.Int("concurrency", 64, "The number of concurrent http requests to make")
     config.PrefetchBufferSize = *flag.Int("item_fetch_buffer", 256, "The number of items to pre-load")
     config.SamplesBufferSize = *flag.Int("item_ready_buffer", 128, "The number of items ready to be served")
+    config.Limit = *flag.Int("limit", 2000, "The number of items to fetch")
 
-    limit := flag.Int("limit", 2000, "The number of items to fetch")
     profile := flag.Bool("profile", false, "Whether to profile the code")
 
     // Parse the flags and instantiate the client
@@ -65,10 +67,10 @@ func main() {
 
     // Fetch all of the binary payloads as they become available
     // NOTE: This is useless, just making sure that we empty the payloads channel
-    for i := 0; i < *limit; i++ {
+    for {
         sample := dataroom_client.GetSample()
         if sample.ID == "" {
-            fmt.Println("No more samples ", i, " samples served")
+            fmt.Println("No more samples")
             break
         }
     }
@@ -78,7 +80,7 @@ func main() {
 
     // Calculate the elapsed time
     elapsedTime := time.Since(startTime)
-    fps := float64(*limit) / elapsedTime.Seconds()
+    fps := float64(config.Limit) / elapsedTime.Seconds()
     fmt.Printf("Total execution time: %.2f \n", elapsedTime.Seconds())
     fmt.Printf("Average throughput: %.2f samples per second\n", fps)
 }
diff --git a/pkg/architecture.go b/pkg/architecture.go
index 6a17c70..f9b86fa 100644
--- a/pkg/architecture.go
+++ b/pkg/architecture.go
@@ -23,6 +23,7 @@ type Sample struct {
     ID               string
     Source           string
     Attributes       map[string]interface{}
+    DuplicateState   int
     Image            ImagePayload
     Masks            map[string]ImagePayload
     AdditionalImages map[string]ImagePayload
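
Note: cmd/main.go above no longer counts samples against the flag value; the limit now lives in the client config, and the loop simply drains until GetSample returns a zero-value Sample. Through the gopy bindings the same consumption pattern reads as follows, as a minimal sketch with an illustrative source name:

    from datago import datago

    client_config = datago.GetDatagoConfig()   # replaces DatagoConfig{} + SetDefaults()
    client_config.Limit = 100                  # 0 (the default) means no limit

    source_config = datago.GetSourceDBConfig()
    source_config.Sources = "COYO"             # illustrative source name
    client_config.SourceConfig = source_config

    client = datago.GetClient(client_config)
    client.Start()                             # optional, trims the latency to the first sample

    while True:
        sample = client.GetSample()
        if sample.ID == "":                    # empty ID signals the limit or an exhausted source
            break
        # ... consume the sample here ...
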
diff --git a/pkg/client.go b/pkg/client.go
index 6d215f4..4a714be 100644
--- a/pkg/client.go
+++ b/pkg/client.go
@@ -30,6 +30,7 @@ type DataSourceConfig struct {
     PageSize  int `json:"page_size"`
     Rank      int `json:"rank"`
     WorldSize int `json:"world_size"`
+    Limit     int `json:"limit"`
 }
 
 type ImageTransformConfig struct {
@@ -41,7 +42,7 @@ type ImageTransformConfig struct {
     PreEncodeImages bool `json:"pre_encode_images"`
 }
 
-func (c *ImageTransformConfig) SetDefaults() {
+func (c *ImageTransformConfig) setDefaults() {
     c.DefaultImageSize = 512
     c.DownsamplingRatio = 16
     c.MinAspectRatio = 0.5
@@ -57,17 +58,25 @@ type DatagoConfig struct {
     PrefetchBufferSize int `json:"prefetch_buffer_size"`
     SamplesBufferSize  int `json:"samples_buffer_size"`
     Concurrency        int `json:"concurrency"`
+    Limit              int `json:"limit"`
 }
 
-func (c *DatagoConfig) SetDefaults() {
-    dbConfig := GeneratorDBConfig{}
-    dbConfig.SetDefaults()
+func (c *DatagoConfig) setDefaults() {
+    dbConfig := SourceDBConfig{}
+    dbConfig.setDefaults()
     c.SourceConfig = dbConfig
-    c.ImageConfig.SetDefaults()
+    c.ImageConfig.setDefaults()
     c.PrefetchBufferSize = 64
     c.SamplesBufferSize = 32
     c.Concurrency = 64
+    c.Limit = 0
+}
+
+func GetDatagoConfig() DatagoConfig {
+    config := DatagoConfig{}
+    config.setDefaults()
+    return config
 }
 
 func DatagoConfigFromJSON(jsonString string) DatagoConfig {
@@ -80,21 +89,30 @@ func DatagoConfigFromJSON(jsonString string) DatagoConfig {
     sourceConfig, err := json.Marshal(tempConfig["source_config"])
     if err != nil {
+        fmt.Println("Error marshalling source_config", tempConfig["source_config"], err)
         log.Panicf("Error marshalling source_config: %v", err)
     }
 
+    // Unmarshal the source config based on the source type
+    // NOTE: The undefined fields will follow the default values
     switch tempConfig["source_type"] {
     case string(SourceTypeDB):
-        var dbConfig GeneratorDBConfig
+        dbConfig := SourceDBConfig{}
+        dbConfig.setDefaults()
+
         err = json.Unmarshal(sourceConfig, &dbConfig)
         if err != nil {
+            fmt.Println("Error unmarshalling DB config", sourceConfig, err)
             log.Panicf("Error unmarshalling DB config: %v", err)
         }
         config.SourceConfig = dbConfig
     case string(SourceTypeFileSystem):
-        var fsConfig GeneratorFileSystemConfig
+        fsConfig := SourceFileSystemConfig{}
+        fsConfig.setDefaults()
+
         err = json.Unmarshal(sourceConfig, &fsConfig)
         if err != nil {
+            fmt.Println("Error unmarshalling Filesystem config", sourceConfig, err)
             log.Panicf("Error unmarshalling FileSystem config: %v", err)
         }
         config.SourceConfig = fsConfig
@@ -127,7 +145,9 @@ type DatagoClient struct {
     waitGroup *sync.WaitGroup
     cancel    context.CancelFunc
 
-    ImageConfig ImageTransformConfig
+    imageConfig   ImageTransformConfig
+    servedSamples int
+    limit         int
 
     // Flexible generator, backend and dispatch goroutines
     generator Generator
@@ -157,14 +177,14 @@ func GetClient(config DatagoConfig) *DatagoClient {
     fmt.Println(reflect.TypeOf(config.SourceConfig))
 
     switch config.SourceConfig.(type) {
-    case GeneratorDBConfig:
+    case SourceDBConfig:
         fmt.Println("Creating a DB-backed dataloader")
-        dbConfig := config.SourceConfig.(GeneratorDBConfig)
+        dbConfig := config.SourceConfig.(SourceDBConfig)
         generator = newDatagoGeneratorDB(dbConfig)
         backend = BackendHTTP{config: &dbConfig, concurrency: config.Concurrency}
-    case GeneratorFileSystemConfig:
+    case SourceFileSystemConfig:
         fmt.Println("Creating a FileSystem-backed dataloader")
-        fsConfig := config.SourceConfig.(GeneratorFileSystemConfig)
+        fsConfig := config.SourceConfig.(SourceFileSystemConfig)
         generator = newDatagoGeneratorFileSystem(fsConfig)
         backend = BackendFileSystem{config: &config, concurrency: config.Concurrency}
     default:
@@ -177,7 +197,9 @@ func GetClient(config DatagoConfig) *DatagoClient {
         chanPages:          make(chan Pages, 2),
         chanSampleMetadata: make(chan SampleDataPointers, config.PrefetchBufferSize),
         chanSamples:        make(chan Sample, config.SamplesBufferSize),
-        ImageConfig:        config.ImageConfig,
+        imageConfig:        config.ImageConfig,
+        servedSamples:      0,
+        limit:              config.Limit,
         context:            nil,
         cancel:             nil,
         waitGroup:          nil,
@@ -218,13 +240,13 @@ func (c *DatagoClient) Start() {
 
     // Optionally crop and resize the images and masks on the fly
     var arAwareTransform *ARAwareTransform = nil
-    if c.ImageConfig.CropAndResize {
+    if c.imageConfig.CropAndResize {
         fmt.Println("Cropping and resizing images")
-        fmt.Println("Base image size | downsampling ratio | min | max:", c.ImageConfig.DefaultImageSize, c.ImageConfig.DownsamplingRatio, c.ImageConfig.MinAspectRatio, c.ImageConfig.MaxAspectRatio)
-        arAwareTransform = newARAwareTransform(c.ImageConfig)
+        fmt.Println("Base image size | downsampling ratio | min | max:", c.imageConfig.DefaultImageSize, c.imageConfig.DownsamplingRatio, c.imageConfig.MinAspectRatio, c.imageConfig.MaxAspectRatio)
+        arAwareTransform = newARAwareTransform(c.imageConfig)
     }
 
-    if c.ImageConfig.PreEncodeImages {
+    if c.imageConfig.PreEncodeImages {
         fmt.Println("Pre-encoding images, we'll return serialized JPG and PNG bytes")
     }
 
@@ -247,7 +269,7 @@ func (c *DatagoClient) Start() {
     wg.Add(1)
     go func() {
         defer wg.Done()
-        c.backend.collectSamples(c.chanSampleMetadata, c.chanSamples, arAwareTransform, c.ImageConfig.PreEncodeImages)
+        c.backend.collectSamples(c.chanSampleMetadata, c.chanSamples, arAwareTransform, c.imageConfig.PreEncodeImages) // Fetch the payloads and deserialize them
     }()
     c.waitGroup = &wg
 }
@@ -255,15 +277,23 @@ func (c *DatagoClient) Start() {
 
 // Get a deserialized sample from the client
 func (c *DatagoClient) GetSample() Sample {
-    if c.cancel == nil {
+    if c.cancel == nil && c.servedSamples == 0 {
         fmt.Println("Dataroom client not started. Starting it on the first sample, this adds some initial latency")
         fmt.Println("Please consider starting the client in anticipation by calling .Start()")
         c.Start()
     }
 
+    if c.limit > 0 && c.servedSamples == c.limit {
+        fmt.Println("Reached the limit of samples to serve, stopping the client")
+        c.Stop()
+        return Sample{}
+    }
+
     if sample, ok := <-c.chanSamples; ok {
+        c.servedSamples++
         return sample
     }
+
+    fmt.Println("chanSamples closed, no more samples to serve")
     return Sample{}
 }
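
Note: since DatagoConfigFromJSON now seeds each source config with its defaults before unmarshalling, a JSON config only has to spell out the fields it overrides. A sketch of what that permits (field values are illustrative):

    import json
    from datago import datago

    client_config = {
        "source_type": datago.SourceTypeDB,  # exposed by the bindings, see the tests below
        "source_config": {
            "sources": "COYO",               # every omitted field falls back to its default
        },
        "limit": 10,                         # new top-level field, 0 means no limit
    }
    client = datago.GetClientFromJSON(json.dumps(client_config))
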
diff --git a/pkg/generator_db.go b/pkg/generator_db.go
index c6b0310..73150a3 100644
--- a/pkg/generator_db.go
+++ b/pkg/generator_db.go
@@ -29,6 +29,7 @@ type urlLatent struct {
 type dbSampleMetadata struct {
     Id             string                 `json:"id"`
     Attributes     map[string]interface{} `json:"attributes"`
+    DuplicateState int                    `json:"duplicate_state"`
     ImageDirectURL string                 `json:"image_direct_url"`
     Latents        []urlLatent            `json:"latents"`
     Tags           []string               `json:"tags"`
@@ -73,7 +74,7 @@ type dbRequest struct {
 }
 
 // -- Define the front end goroutine ---------------------------------------------------------------------------------------------------------------------------------------------------------------
-type GeneratorDBConfig struct {
+type SourceDBConfig struct {
     DataSourceConfig
     Sources       string `json:"sources"`
     RequireImages bool   `json:"require_images"`
@@ -86,7 +87,9 @@
     LacksMasks   string `json:"lacks_masks"`
     HasLatents   string `json:"has_latents"`
     LacksLatents string `json:"lacks_latents"`
-    ReturnLatents string `json:"return_latents"`
+
+    ReturnLatents        string `json:"return_latents"`
+    ReturnDuplicateState bool   `json:"return_duplicate_state"`
 
     MinShortEdge int `json:"min_short_edge"`
     MaxShortEdge int `json:"max_short_edge"`
@@ -95,7 +98,7 @@
     RandomSampling bool `json:"random_sampling"`
 }
 
-func (c *GeneratorDBConfig) SetDefaults() {
+func (c *SourceDBConfig) setDefaults() {
     c.PageSize = 512
     c.Rank = -1
     c.WorldSize = -1
@@ -112,15 +115,17 @@
     c.HasLatents = ""
     c.LacksLatents = ""
     c.ReturnLatents = ""
+    c.ReturnDuplicateState = false
 
     c.MinShortEdge = -1
     c.MaxShortEdge = -1
     c.MinPixelCount = -1
     c.MaxPixelCount = -1
     c.RandomSampling = false
+
 }
 
-func (c *GeneratorDBConfig) getDbRequest() dbRequest {
+func (c *SourceDBConfig) getDbRequest() dbRequest {
     fields := "attributes,image_direct_url"
 
     if len(c.HasLatents) > 0 || len(c.HasMasks) > 0 {
@@ -142,6 +147,11 @@
         fmt.Println("Including embeddings")
     }
 
+    if c.ReturnDuplicateState {
+        fields += ",duplicate_state"
+        fmt.Println("Including duplicate state")
+    }
+
     // Report some config data
     fmt.Println("Rank | World size:", c.Rank, c.WorldSize)
     fmt.Println("Sources:", c.Sources, "| Fields:", fields)
@@ -153,6 +163,13 @@
         return fmt.Sprintf("%d", val)
     }
 
+    // Align rank and world size with the partitioning
+    if c.WorldSize < 2 {
+        // No partitioning
+        c.WorldSize = -1
+        c.Rank = -1
+    }
+
     return dbRequest{
         fields:  fields,
         sources: c.Sources,
@@ -176,12 +193,18 @@
     }
 }
 
+func GetSourceDBConfig() SourceDBConfig {
+    config := SourceDBConfig{}
+    config.setDefaults()
+    return config
+}
+
 type datagoGeneratorDB struct {
     baseRequest http.Request
-    config      GeneratorDBConfig
+    config      SourceDBConfig
 }
 
-func newDatagoGeneratorDB(config GeneratorDBConfig) datagoGeneratorDB {
+func newDatagoGeneratorDB(config SourceDBConfig) datagoGeneratorDB {
     request := config.getDbRequest()
 
     api_key := os.Getenv("DATAROOM_API_KEY")
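
Note: the new ReturnDuplicateState flag adds duplicate_state to the requested fields and surfaces it as Sample.DuplicateState. A minimal sketch through the bindings (source name illustrative):

    from datago import datago

    client_config = datago.GetDatagoConfig()
    source_config = datago.GetSourceDBConfig()
    source_config.Sources = "COYO"             # illustrative
    source_config.ReturnDuplicateState = True  # request the extra duplicate_state field
    client_config.SourceConfig = source_config

    client = datago.GetClient(client_config)
    sample = client.GetSample()                # auto-starts the client on the first call
    print(sample.DuplicateState)               # 0, 1 or 2, per the test below
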
diff --git a/pkg/generator_filesystem.go b/pkg/generator_filesystem.go
index 5d2aec7..ad0dfba 100644
--- a/pkg/generator_filesystem.go
+++ b/pkg/generator_filesystem.go
@@ -18,12 +18,12 @@ type fsSampleMetadata struct {
 }
 
 // -- Define the front end goroutine ---------------------------------------------------------------------------------------------------------------------------------------------------------------
-type GeneratorFileSystemConfig struct {
+type SourceFileSystemConfig struct {
     DataSourceConfig
     RootPath string `json:"root_path"`
 }
 
-func (c *GeneratorFileSystemConfig) SetDefaults() {
+func (c *SourceFileSystemConfig) setDefaults() {
     c.PageSize = 512
     c.Rank = 0
     c.WorldSize = 1
@@ -31,12 +31,18 @@
     c.RootPath = os.Getenv("DATAROOM_TEST_FILESYSTEM")
 }
 
+func GetSourceFileSystemConfig() SourceFileSystemConfig {
+    config := SourceFileSystemConfig{}
+    config.setDefaults()
+    return config
+}
+
 type datagoGeneratorFileSystem struct {
     extensions set
-    config     GeneratorFileSystemConfig
+    config     SourceFileSystemConfig
 }
 
-func newDatagoGeneratorFileSystem(config GeneratorFileSystemConfig) datagoGeneratorFileSystem {
+func newDatagoGeneratorFileSystem(config SourceFileSystemConfig) datagoGeneratorFileSystem {
     supported_img_extensions := []string{".jpg", ".jpeg", ".png", ".JPEG", ".JPG", ".PNG"}
     var extensionsMap = make(set)
     for _, ext := range supported_img_extensions {
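
Note: GetSourceFileSystemConfig mirrors the DB-side factory, with RootPath defaulting to $DATAROOM_TEST_FILESYSTEM. A sketch of the intended usage through the bindings; the path is hypothetical, and the filesystem route is only exercised by a disabled test below, so treat this as indicative:

    from datago import datago

    client_config = datago.GetDatagoConfig()

    source_config = datago.GetSourceFileSystemConfig()  # RootPath defaults to $DATAROOM_TEST_FILESYSTEM
    source_config.RootPath = "/data/images"             # hypothetical path
    source_config.PageSize = 10

    client_config.SourceConfig = source_config
    client = datago.GetClient(client_config)
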
diff --git a/pkg/serdes.go b/pkg/serdes.go
index 57b535a..417b29a 100644
--- a/pkg/serdes.go
+++ b/pkg/serdes.go
@@ -185,7 +185,7 @@ func fetchImage(client *http.Client, url string, retries int, transform *ARAware
     return nil, -1., err_report
 }
 
-func fetchSample(config *GeneratorDBConfig, http_client *http.Client, sample_result dbSampleMetadata, transform *ARAwareTransform, pre_encode_image bool) *Sample {
+func fetchSample(config *SourceDBConfig, http_client *http.Client, sample_result dbSampleMetadata, transform *ARAwareTransform, pre_encode_image bool) *Sample {
     // Per sample work:
     // - fetch the raw payloads
     // - deserialize / decode, depending on the types
diff --git a/pkg/worker_http.go b/pkg/worker_http.go
index e3588b3..bd8f668 100644
--- a/pkg/worker_http.go
+++ b/pkg/worker_http.go
@@ -7,7 +7,7 @@ import (
 )
 
 type BackendHTTP struct {
-    config      *GeneratorDBConfig
+    config      *SourceDBConfig
     concurrency int
 }
diff --git a/python/benchmark_db.py b/python/benchmark_db.py
index a648256..26b5601 100644
--- a/python/benchmark_db.py
+++ b/python/benchmark_db.py
@@ -1,10 +1,9 @@
 from datago import datago  # type: ignore
 import time
-import typer
 from tqdm import tqdm
 import numpy as np
 from go_types import go_array_to_pil_image, go_array_to_numpy
-
+import typer
 
 def benchmark(
@@ -18,21 +17,24 @@ def benchmark(
     test_latents: bool = typer.Option(True, help="Test latents"),
 ):
     print(f"Running benchmark for {source} - {limit} samples")
-    client_config = datago.DatagoConfig()
-    client_config.SetDefaults()
+
+    # Get a generic client config
+    client_config = datago.GetDatagoConfig()
     client_config.ImageConfig.CropAndResize = crop_and_resize
 
-    source_config = datago.GeneratorDBConfig()
-    source_config.SetDefaults()
+    # Specify the source parameters as you see fit
+    source_config = datago.GetSourceDBConfig()
     source_config.Sources = source
     source_config.RequireImages = require_images
     source_config.RequireEmbeddings = require_embeddings
     source_config.HasMasks = "segmentation_mask" if test_masks else ""
     source_config.HasLatents = "caption_latent_t5xxl" if test_latents else ""
 
-    client_config.SourceConfig = source_config
+    # Get a new client instance, happy benchmarking
+    client_config.SourceConfig = source_config
     client = datago.GetClient(client_config)
-    client.Start()
+
+    client.Start()  # Optional, but good practice: starting the client early reduces the latency to the first sample (e.g. while you're instantiating models)
 
     start = time.time()
 
     # Make sure in the following that we compare apples to apples, meaning in that case
diff --git a/python/benchmark_filesystem.py b/python/benchmark_filesystem.py
index 6f26b17..266e416 100644
--- a/python/benchmark_filesystem.py
+++ b/python/benchmark_filesystem.py
@@ -1,11 +1,10 @@
 from datago import datago  # type: ignore
 import time
-import typer
 from tqdm import tqdm
 from go_types import go_array_to_pil_image
 import os
 import json
-
+import typer
 
 def benchmark(
     root_path: str = typer.Option(
@@ -15,7 +14,7 @@ def benchmark(
     crop_and_resize: bool = typer.Option(
         True, help="Crop and resize the images on the fly"
     ),
-    concurrency: int = typer.Option(32, help="The number of concurrent workers"),
+    concurrency: int = typer.Option(64, help="The number of coroutines"),
     compare_torch: bool = typer.Option(True, help="Compare against torch dataloader"),
 ):
     print(f"Running benchmark for {root_path} - {limit} samples")
@@ -38,11 +37,13 @@ def benchmark(
         "prefetch_buffer_size": 64,
         "samples_buffer_size": 128,
         "concurrency": concurrency,
+        "limit": limit,
     }
 
     client = datago.GetClientFromJSON(json.dumps(client_config))
-    client.Start()
+
     start = time.time()
+    client.Start()
 
     # Make sure in the following that we compare apples to apples, meaning in that case
     # that we materialize the payloads in the python scope in the expected format
@@ -53,6 +54,10 @@ def benchmark(
         if sample.ID and hasattr(sample, "Image"):
             img = go_array_to_pil_image(sample.Image)
 
+        if not sample.ID:
+            print("No more samples")
+            break
+
     fps = limit / (time.time() - start)
     print(f"Datago FPS {fps:.2f}")
     client.Stop()
@@ -71,7 +76,7 @@ def benchmark(
         transform = (
             transforms.Compose(
                 [
-                    transforms.Resize((512, 512)),
+                    transforms.Resize((512, 512), interpolation=transforms.InterpolationMode.BICUBIC),
                 ]
             )
             if crop_and_resize
@@ -79,11 +84,11 @@ def benchmark(
         )
 
         # Create the ImageFolder dataset
-        dataset = datasets.ImageFolder(root=root_path, transform=transform)
+        dataset = datasets.ImageFolder(root=root_path, transform=transform, allow_empty=True)
 
-        # Create a DataLoader to load the images in batches
+        # Create a DataLoader to allow for multiple workers
        dataloader = DataLoader(
-            dataset, batch_size=32, shuffle=False, num_workers=8, collate_fn=lambda x: x
+            dataset, batch_size=1, shuffle=False, num_workers=8, collate_fn=lambda x: x
         )
 
         # Iterate over the DataLoader
diff --git a/python/test_datago_db.py b/python/test_datago_db.py
new file mode 100644
index 0000000..369e215
--- /dev/null
+++ b/python/test_datago_db.py
@@ -0,0 +1,186 @@
+from datago import datago
+import pytest
+import os
+import json
+from go_types import go_array_to_pil_image, go_array_to_numpy
+
+
+def get_test_source() -> str:
+    test_source = os.getenv("DATAROOM_TEST_SOURCE", "COYO")
+    assert test_source is not None, "Please set DATAROOM_TEST_SOURCE"
+    return test_source
+
+
+def get_json_config():
+    client_config = {
+        "source_type": datago.SourceTypeDB,
+        "source_config": {
+            "page_size": 10,
+            "sources": get_test_source(),
+            "require_images": True,
+            "has_masks": "segmentation_mask",
+            "has_latents": "masked_image",
+            "has_attributes": "caption_coca,caption_cogvlm,caption_moondream",
+            "return_latents": "masked_image",
+            "rank": 0,
+            "world_size": 1,
+        },
+        "image_config": {
+            "crop_and_resize": False,
+            "default_image_size": 512,
+            "downsampling_ratio": 16,
+            "min_aspect_ratio": 0.5,
+            "max_aspect_ratio": 2.0,
+            "pre_encode_images": False,
+        },
+        "prefetch_buffer_size": 64,
+        "samples_buffer_size": 128,
+        "concurrency": 1,
+        "limit": 10,
+    }
+    return client_config
+
+
+def get_dataset(client_config: dict):
+    client = datago.GetClientFromJSON(json.dumps(client_config))
+
+    class Dataset:
+        def __init__(self, client):
+            self.client = client
+
+        def __iter__(self):
+            return self
+
+        def __next__(self):
+            new_sample = self.client.GetSample()
+            if new_sample.ID == "":
+                raise StopIteration
+            return new_sample
+
+    return Dataset(client)
+
+
+def test_get_sample_db():
+    # Check that we can instantiate a client and get a sample, nothing more
+    client_config = datago.GetDatagoConfig()
+    client_config.SamplesBufferSize = 10
+
+    source_config = datago.GetSourceDBConfig()
+    source_config.Sources = get_test_source()
+    client_config.SourceConfig = source_config
+
+    client = datago.GetClient(client_config)
+    data = client.GetSample()
+    assert data.ID != ""
+
+
+N_SAMPLES = 3
+
+
+def test_caption_and_image():
+    client_config = get_json_config()
+    dataset = get_dataset(client_config)
+
+    def check_image(img, channels=3):
+        assert img.Height > 0
+        assert img.Width > 0
+        assert img.Height <= img.OriginalHeight
+        assert img.Width <= img.OriginalWidth
+        assert img.Channels == channels
+
+    for i, sample in enumerate(dataset):
+        assert sample.Source != ""
+        assert sample.ID != ""
+
+        assert len(sample.Attributes["caption_coca"]) != len(
+            sample.Attributes["caption_cogvlm"]
+        ), "Caption lengths should not be equal"
+
+        check_image(sample.Image, 3)
+        check_image(sample.AdditionalImages["masked_image"], 3)
+        check_image(sample.Masks["segmentation_mask"], 1)
+
+        # Check the image decoding
+        assert go_array_to_pil_image(sample.Image).mode == "RGB", "Image should be RGB"
+        assert (
+            go_array_to_pil_image(sample.AdditionalImages["masked_image"]).mode == "RGB"
+        ), "Image should be RGB"
+        assert (
+            go_array_to_pil_image(sample.Masks["segmentation_mask"]).mode == "L"
+        ), "Mask should be L"
+
+        if i > N_SAMPLES:
+            break
+
+
+def test_image_resize():
+    client_config = get_json_config()
+    client_config["image_config"]["crop_and_resize"] = True
+    dataset = get_dataset(client_config)
+
+    for i, sample in enumerate(dataset):
+        # Assert that all the images in the sample have the same size
+        assert (
+            sample.Image.Height
+            == sample.AdditionalImages["masked_image"].Height
+            == sample.Masks["segmentation_mask"].Height
+            and sample.Image.Height > 0
+        )
+        assert (
+            sample.Image.Width
+            == sample.AdditionalImages["masked_image"].Width
+            == sample.Masks["segmentation_mask"].Width
+            and sample.Image.Width > 0
+        )
+        if i > N_SAMPLES:
+            break
+
+
+def test_has_tags():
+    client_config = get_json_config()
+    client_config["source_config"]["tags"] = "v4_trainset_hq"
+
+    dataset = get_dataset(client_config)
+    sample = next(iter(dataset))
+
+    assert "v4_trainset_hq" in sample.Tags, "v4_trainset_hq should be in the tags"
+
+
+def no_test_jpg_compression():
+    # Check that the images are compressed as expected
+    client_config = get_json_config()
+    client_config["image_config"]["pre_encode_images"] = True
+    dataset = get_dataset(client_config)
+
+    sample = next(iter(dataset))
+
+    assert go_array_to_pil_image(sample.Image).mode == "RGB", "Image should be RGB"
+    assert (
+        go_array_to_pil_image(sample.AdditionalImages["masked_image"]).mode == "RGB"
+    ), "Image should be RGB"
+    assert (
+        go_array_to_pil_image(sample.Masks["segmentation_mask"]).mode == "L"
+    ), "Mask should be L"
+
+    # Check the embeddings decoding
+    assert (
+        go_array_to_numpy(sample.CocaEmbedding) is not None
+    ), "Embedding should be set"
+
+
+def test_duplicate_state():
+    client_config = get_json_config()
+    client_config["source_config"]["return_duplicate_state"] = True
+    dataset = get_dataset(client_config)
+
+    sample = next(iter(dataset))
+    assert sample.DuplicateState in [
+        0,
+        1,
+        2,
+    ], "Duplicate state should be 0, 1 or 2"
+
+
+if __name__ == "__main__":
+    pytest.main(["-v", __file__])
diff --git a/python/test_datago_filesystem.py b/python/test_datago_filesystem.py
new file mode 100644
index 0000000..8d7ba6a
--- /dev/null
+++ b/python/test_datago_filesystem.py
@@ -0,0 +1,28 @@
+import os
+from PIL import Image
+from datago import datago
+
+
+# FIXME: Would need to generate more fake data to test this
+def no_test_get_sample_filesystem():
+    cwd = os.getcwd()
+
+    try:
+        # Dump a sample image to the filesystem
+        img = Image.new("RGB", (100, 100))
+        img.save(cwd + "/test.png")
+
+        # Check that we can instantiate a client and get a sample, nothing more
+        client_config = datago.GetDatagoConfig()
+        client_config.SourceType = "filesystem"
+        client_config.SamplesBufferSize = 1
+
+        source_config = datago.SourceFileSystemConfig()
+        source_config.RootPath = cwd
+        source_config.PageSize = 1
+
+        client = datago.GetClient(client_config, source_config)
+        data = client.GetSample()
+        assert data.ID != ""
+    finally:
+        os.remove(cwd + "/test.png")
diff --git a/python/tests/__pycache__/datago_test.cpython-311-pytest-8.3.3.pyc b/python/tests/__pycache__/datago_test.cpython-311-pytest-8.3.3.pyc
deleted file mode 100644
index ba915fd..0000000
Binary files a/python/tests/__pycache__/datago_test.cpython-311-pytest-8.3.3.pyc and /dev/null differ
diff --git a/python/tests/__pycache__/test_datago.cpython-311-pytest-8.3.3.pyc b/python/tests/__pycache__/test_datago.cpython-311-pytest-8.3.3.pyc
deleted file mode 100644
index 56f2a6a..0000000
Binary files a/python/tests/__pycache__/test_datago.cpython-311-pytest-8.3.3.pyc and /dev/null differ
diff --git a/python/tests/test_datago.py b/python/tests/test_datago.py
deleted file mode 100644
index 07543d6..0000000
--- a/python/tests/test_datago.py
+++ /dev/null
@@ -1,54 +0,0 @@
-from datago import datago
-import pytest
-import os
-from PIL import Image
-
-
-def get_test_source():
-    return os.getenv("DATAROOM_TEST_SOURCE")
-
-
-def test_get_sample_db():
-    # Check that we can instantiate a client and get a sample, nothing more
-    client_config = datago.DatagoConfig()
-    client_config.SetDefaults()
-    client_config.SamplesBufferSize = 10
-
-    source_config = datago.GeneratorDBConfig()
-    source_config.SetDefaults()
-    source_config.Sources = get_test_source()
-
-    client = datago.GetClient(client_config)
-    data = client.GetSample()
-    assert data.ID != ""
-
-
-def no_test_get_sample_filesystem():
-    cwd = os.getcwd()
-
-    try:
-        # Dump a sample image to the filesystem
-        img = Image.new("RGB", (100, 100))
-        img.save(cwd + "/test.png")
-
-        # Check that we can instantiate a client and get a sample, nothing more
-        client_config = datago.DatagoConfig()
-        client_config.SetDefaults()
-        client_config.SourceType = "filesystem"
-        client_config.SamplesBufferSize = 1
-
-        source_config = datago.GeneratorFileSystemConfig()
-        source_config.RootPath = cwd
-        source_config.PageSize = 1
-
-        client = datago.GetClient(client_config, source_config)
-        data = client.GetSample()
-        assert data.ID != ""
-    finally:
-        os.remove(cwd + "/test.png")
-
-
-# TODO: Backport all the image correctness tests
-
-if __name__ == "__main__":
-    pytest.main(["-v", __file__])
diff --git a/requirements-tests.txt b/requirements-tests.txt
new file mode 100644
index 0000000..77f7cbd
--- /dev/null
+++ b/requirements-tests.txt
@@ -0,0 +1,7 @@
+# Get core deps.
+-r requirements.txt
+
+pytest
+typer
+tqdm
+numpy
diff --git a/requirements.txt b/requirements.txt
index 33527cc..3868fb1 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,2 +1 @@
-pytest
 pillow
diff --git a/tests/client_test.go b/tests/client_test.go
index 5e38cc1..aa5498a 100644
--- a/tests/client_test.go
+++ b/tests/client_test.go
@@ -15,15 +15,13 @@ func get_test_source() string {
 }
 
 func get_default_test_config() datago.DatagoConfig {
-    config := datago.DatagoConfig{}
-    config.SetDefaults()
+    config := datago.GetDatagoConfig()
 
-    db_config := datago.GeneratorDBConfig{}
-    db_config.SetDefaults()
+    db_config := datago.GetSourceDBConfig()
     db_config.Sources = get_test_source()
     db_config.PageSize = 32
-    config.SourceConfig = db_config
 
+    config.SourceConfig = db_config
     return config
 }
 
@@ -104,7 +102,7 @@ func TestExtraFields(t *testing.T) {
     clientConfig := get_default_test_config()
     clientConfig.SamplesBufferSize = 1
 
-    dbConfig := clientConfig.SourceConfig.(datago.GeneratorDBConfig)
+    dbConfig := clientConfig.SourceConfig.(datago.SourceDBConfig)
     dbConfig.HasLatents = "masked_image"
     dbConfig.HasMasks = "segmentation_mask"
     clientConfig.SourceConfig = dbConfig
@@ -174,7 +172,7 @@ func TestImageBufferCompression(t *testing.T) {
     clientConfig.SamplesBufferSize = 1
     clientConfig.ImageConfig.PreEncodeImages = true
 
-    dbConfig := clientConfig.SourceConfig.(datago.GeneratorDBConfig)
+    dbConfig := clientConfig.SourceConfig.(datago.SourceDBConfig)
     dbConfig.HasLatents = "masked_image"
     dbConfig.HasMasks = "segmentation_mask"
     clientConfig.SourceConfig = dbConfig
@@ -250,7 +248,7 @@ func TestRanks(t *testing.T) {
    clientConfig := get_default_test_config()
     clientConfig.SamplesBufferSize = 1
 
-    dbConfig := clientConfig.SourceConfig.(datago.GeneratorDBConfig)
+    dbConfig := clientConfig.SourceConfig.(datago.SourceDBConfig)
     dbConfig.WorldSize = 2
     dbConfig.Rank = 0
     clientConfig.SourceConfig = dbConfig
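
In short, for downstream users this change boils down to the factory functions and the Generator* to Source* renames; a before/after sketch (source name illustrative):

    from datago import datago

    # Before this change:
    #   client_config = datago.DatagoConfig();      client_config.SetDefaults()
    #   source_config = datago.GeneratorDBConfig(); source_config.SetDefaults()
    # After:
    client_config = datago.GetDatagoConfig()
    source_config = datago.GetSourceDBConfig()
    source_config.Sources = "COYO"  # illustrative
    client_config.SourceConfig = source_config
    client = datago.GetClient(client_config)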