diff --git a/.github/workflows/go.yml b/.github/workflows/go.yml index 7b0e5a4..ecea76b 100644 --- a/.github/workflows/go.yml +++ b/.github/workflows/go.yml @@ -26,7 +26,7 @@ jobs: sudo apt-get -y install libvips-dev - name: Build - run: cd src/cmd/main && go build -v main.go + run: cd src/cmd && go build -v main.go - name: Test env: diff --git a/.github/workflows/gopy.yml b/.github/workflows/gopy.yml index f2694f7..3e405cc 100644 --- a/.github/workflows/gopy.yml +++ b/.github/workflows/gopy.yml @@ -38,16 +38,17 @@ jobs: - name: Build python module run: | - cd src/pkg/client + cd src/pkg gopy pkg -author="Photoroom" -email="team@photoroom.com" -name="datago" . - export DESTINATION="../../../build" + export DESTINATION="../../build" mkdir -p $DESTINATION/datago mv datago/* $DESTINATION/datago/. mv setup.py $DESTINATION/. mv Makefile $DESTINATION/. mv README.md $DESTINATION/. rm LICENSE MANIFEST.in - cd ../../../build + ls ../.. + cd ../../build - name: Install python module run: | diff --git a/README.md b/README.md index a3344a7..0bb6ccb 100644 --- a/README.md +++ b/README.md @@ -107,7 +107,7 @@ NOTE: - - Either `export PATH=$PATH:~/go/bin` or add it to your .bashrc - you may need this to make sure that LDD looks at the current folder `export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:.` -then from the /pkg/client folder: +then from the /pkg folder: ```bash $ gopy pkg -author="Photoroom" -email="team@photoroom.com" -url="" -name="datago" -version="0.0.1" . diff --git a/generate_python_package.sh b/generate_python_package.sh index 2839eb2..8d70361 100755 --- a/generate_python_package.sh +++ b/generate_python_package.sh @@ -11,7 +11,7 @@ DESTINATION="../../../python_$python_version" rm -rf $DESTINATION # Build the python package via the gopy toolchain -cd src/pkg/client +cd src/pkg gopy pkg -author="Photoroom" -email="team@photoroom.com" -url="" -name="datago" -version="0.3" . mkdir -p $DESTINATION/datago mv datago/* $DESTINATION/datago/. 
@@ -21,6 +21,6 @@ mv README.md $DESTINATION/. rm LICENSE rm MANIFEST.in -cd ../../.. +cd ../.. diff --git a/python_tests/datago_test.py b/python_tests/datago_test.py index 921312c..ef305f2 100644 --- a/python_tests/datago_test.py +++ b/python_tests/datago_test.py @@ -1,13 +1,14 @@ from datago import datago import pytest import os +from PIL import Image def get_test_source(): return os.getenv("DATAROOM_TEST_SOURCE") -def test_get_sample(): +def test_get_sample_db(): # Check that we can instantiate a client and get a sample, nothing more config = datago.GetDefaultConfig() config.source = get_test_source() @@ -17,6 +18,27 @@ def test_get_sample(): assert data.ID != "" +def test_get_sample_filesystem(): + cwd = os.getcwd() + + try: + # Dump a sample image to the filesystem + img = Image.new("RGB", (100, 100)) + img.save(cwd + "/test.png") + + # Check that we can instantiate a client and get a sample, nothing more + config = datago.GetDefaultConfig() + config.SourceType = "filesystem" + config.Sources = cwd + config.sample = 1 + + client = datago.GetClient(config) + data = client.GetSample() + assert data.ID != "" + finally: + os.remove(cwd + "/test.png") + + # TODO: Backport all the image correctness tests if __name__ == "__main__": diff --git a/requirements.txt b/requirements.txt index 55b033e..69d58e5 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1 +1,2 @@ -pytest \ No newline at end of file +pytest +pillow \ No newline at end of file diff --git a/src/cmd/main/main.go b/src/cmd/main.go similarity index 62% rename from src/cmd/main/main.go rename to src/cmd/main.go index 6307f60..e8241b8 100644 --- a/src/cmd/main/main.go +++ b/src/cmd/main.go @@ -1,7 +1,7 @@ package main import ( - datago "datago/pkg/client" + datago "datago/pkg" "flag" "fmt" "os" @@ -13,6 +13,8 @@ import ( func main() { // Define flags client_config := datago.GetDefaultConfig() + client_config.SourceType = datago.SourceTypeFileSystem + client_config.SourceConfig = 
datago.GeneratorFileSystemConfig{RootPath: os.Getenv("DATAROOM_TEST_FILESYSTEM"), PageSize: 10} client_config.DefaultImageSize = 1024 client_config.DownsamplingRatio = 32 @@ -21,20 +23,6 @@ func main() { client_config.PrefetchBufferSize = *flag.Int("item_fetch_buffer", 256, "The number of items to pre-load") client_config.SamplesBufferSize = *flag.Int("item_ready_buffer", 128, "The number of items ready to be served") - client_config.Sources = *flag.String("source", "GETTY", "The source for the items") - client_config.RequireImages = *flag.Bool("require_images", true, "Whether the items require images") - client_config.RequireEmbeddings = *flag.Bool("require_embeddings", false, "Whether the items require the DB embeddings") - - client_config.Tags = *flag.String("tags", "", "The tags to filter for") - client_config.TagsNE = *flag.String("tags__ne", "", "The tags that the samples should not have") - client_config.HasMasks = *flag.String("has_masks", "", "The masks to filter for") - client_config.HasLatents = *flag.String("has_latents", "", "The masks to filter for") - client_config.HasAttributes = *flag.String("has_attributes", "", "The attributes to filter for") - - client_config.LacksMasks = *flag.String("lacks_masks", "", "The masks to filter against") - client_config.LacksLatents = *flag.String("lacks_latents", "", "The masks to filter against") - client_config.LacksAttributes = *flag.String("lacks_attributes", "", "The attributes to filter against") - limit := flag.Int("limit", 2000, "The number of items to fetch") profile := flag.Bool("profile", false, "Whether to profile the code") @@ -71,7 +59,7 @@ func main() { for i := 0; i < *limit; i++ { sample := dataroom_client.GetSample() if sample.ID == "" { - fmt.Printf("Error fetching sample") + fmt.Println("No more samples ", i, " samples served") break } } @@ -83,5 +71,5 @@ func main() { elapsedTime := time.Since(startTime) fps := float64(*limit) / elapsedTime.Seconds() fmt.Printf("Total execution time: %.2f \n", 
elapsedTime.Seconds()) - fmt.Printf("Average fetch rate: %.2f fetches per second\n", fps) + fmt.Printf("Average throughput: %.2f samples per second\n", fps) } diff --git a/src/pkg/architecture.go b/src/pkg/architecture.go new file mode 100644 index 0000000..6a17c70 --- /dev/null +++ b/src/pkg/architecture.go @@ -0,0 +1,52 @@ +package datago + +import "context" + +// --- Sample data structures - these will be exposed to the Python world --------------------------------------------------------------------------------------------------------------------------------------------------------------- +type LatentPayload struct { + Data []byte + Len int + DataPtr uintptr +} + +type ImagePayload struct { + Data []byte + OriginalHeight int // Good indicator of the image frequency dbResponse at the current resolution + OriginalWidth int + Height int // Useful to decode the current payload + Width int + Channels int + DataPtr uintptr +} + +type Sample struct { + ID string + Source string + Attributes map[string]interface{} + Image ImagePayload + Masks map[string]ImagePayload + AdditionalImages map[string]ImagePayload + Latents map[string]LatentPayload + CocaEmbedding []float32 + Tags []string +} + +// --- Generator and Backend interfaces --------------------------------------------------------------------------------------------------------------------------------------------------------------- + +// The generator will be responsible for producing pages of metadata which can be dispatched +// to the dispatch goroutine. 
The metadata will be used to fetch the actual payloads + +type SampleDataPointers interface{} + +type Pages struct { + samplesDataPointers []SampleDataPointers +} + +type Generator interface { + generatePages(ctx context.Context, chanPages chan Pages) +} + +// The backend will be responsible for fetching the payloads and deserializing them +type Backend interface { + collectSamples(chanSampleMetadata chan SampleDataPointers, chanSamples chan Sample, transform *ARAwareTransform, pre_encode_images bool) +} diff --git a/src/pkg/client/client.go b/src/pkg/client.go similarity index 57% rename from src/pkg/client/client.go rename to src/pkg/client.go index e444e65..36baca3 100644 --- a/src/pkg/client/client.go +++ b/src/pkg/client.go @@ -4,82 +4,36 @@ import ( "context" "fmt" "log" - "net/http" "os" "runtime" "runtime/debug" - "strings" "sync" "github.com/davidbyttow/govips/v2/vips" ) -// --- Sample data structures - these will be exposed to the Python world --------------------------------------------------------------------------------------------------------------------------------------------------------------- -type LatentPayload struct { - Data []byte - Len int - DataPtr uintptr -} - -type ImagePayload struct { - Data []byte - OriginalHeight int // Good indicator of the image frequency dbResponse at the current resolution - OriginalWidth int - Height int // Useful to decode the current payload - Width int - Channels int - DataPtr uintptr -} - -type Sample struct { - ID string - Source string - Attributes map[string]interface{} - Image ImagePayload - Masks map[string]ImagePayload - AdditionalImages map[string]ImagePayload - Latents map[string]LatentPayload - CocaEmbedding []float32 - Tags []string -} - -type URLPayload struct { - url string - content []byte -} - // ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- // 
Public interface for this package, will be reflected in the python bindings type DatagoSourceType string const ( - SourceTypeDB DatagoSourceType = "DB" - SourceTypeLocalStorage DatagoSourceType = "LocalStorage" + SourceTypeDB DatagoSourceType = "DB" + SourceTypeFileSystem DatagoSourceType = "FileSystem" // incoming: object storage ) +type DataSourceConfig interface{} + type DatagoConfig struct { - Sources string SourceType DatagoSourceType - RequireImages bool - RequireEmbeddings bool - Tags string - TagsNE string - HasAttributes string - LacksAttributes string - HasMasks string - LacksMasks string - HasLatents string - LacksLatents string + SourceConfig DataSourceConfig CropAndResize bool DefaultImageSize int DownsamplingRatio int MinAspectRatio float64 MaxAspectRatio float64 PreEncodeImages bool - Rank uint32 - WorldSize uint32 PrefetchBufferSize int SamplesBufferSize int ConcurrentDownloads int @@ -88,21 +42,11 @@ type DatagoConfig struct { type DatagoClient struct { concurrency int - baseRequest http.Request context context.Context waitGroup *sync.WaitGroup cancel context.CancelFunc - // Request parameters - sources string - require_images bool - require_embeddings bool - has_masks []string - has_latents []string - rank uint32 - world_size uint32 - // Online transform parameters crop_and_resize bool default_image_size int @@ -111,72 +55,53 @@ type DatagoClient struct { max_aspect_ratio float64 pre_encode_images bool - // Flexible frontend, backend and dispatch goroutines - frontend Frontend - backend Backend + // Flexible generator, backend and dispatch goroutines + generator Generator + backend Backend // Channels - these will be used to communicate between the background goroutines - chanPageResults chan dbResponse // TODO: Make this a generic type - chanSampleMetadata chan dbSampleMetadata // TODO: Make this a generic type + chanPages chan Pages + chanSampleMetadata chan SampleDataPointers chanSamples chan Sample } -// 
----------------------------------------------------------------------------------------------------------------- -// Define the interfaces that the different features will needd to follow - -// The frontend will be responsible for producing pages of metadata which can be dispatched -// to the dispatch goroutine. The metadata will be used to fetch the actual payloads - -type Frontend interface { - collectPages(ctx context.Context, chanPageResults chan dbResponse) -} - -// The backend will be responsible for fetching the payloads and deserializing them -type Backend interface { - collectSamples(chanSampleMetadata chan dbSampleMetadata, chanSamples chan Sample, transform *ARAwareTransform) -} - // ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- func GetDefaultConfig() DatagoConfig { + dbConfig := GetDefaultDBConfig() + return DatagoConfig{ - Sources: "", - SourceType: SourceTypeDB, - RequireImages: true, - RequireEmbeddings: false, - Tags: "", - TagsNE: "", - HasAttributes: "", - LacksAttributes: "", - HasMasks: "", - LacksMasks: "", - HasLatents: "", - LacksLatents: "", - CropAndResize: false, - DefaultImageSize: 512, - DownsamplingRatio: 16, - MinAspectRatio: 0.5, - MaxAspectRatio: 2.0, - PreEncodeImages: false, - Rank: 0, - WorldSize: 1, - PrefetchBufferSize: 8, - SamplesBufferSize: 8, - ConcurrentDownloads: 2, - PageSize: 1000, + SourceType: SourceTypeDB, + SourceConfig: dbConfig, + CropAndResize: false, + DefaultImageSize: 512, + DownsamplingRatio: 16, + MinAspectRatio: 0.5, + MaxAspectRatio: 2.0, + PreEncodeImages: false, + PrefetchBufferSize: 8, + SamplesBufferSize: 8, + PageSize: 20, // 1000 for a vectorDB, make this a default which depends on the source type } } // Create a new Dataroom Client func GetClient(config DatagoConfig) *DatagoClient { - // Create the frontend and backend - var frontend Frontend + // 
Create the generator and backend + var generator Generator var backend Backend if config.SourceType == SourceTypeDB { - frontend = newDatagoFrontendDB(config) - backend = BackendHTTP{config: &config} + fmt.Println("Creating a DB-backed dataloader") + db_config := config.SourceConfig.(GeneratorDBConfig) + generator = newDatagoGeneratorDB(db_config) + backend = BackendHTTP{config: &db_config} + } else if config.SourceType == SourceTypeFileSystem { + fmt.Println("Creating a FileSystem-backed dataloader") + db_config := config.SourceConfig.(GeneratorFileSystemConfig) + generator = newDatagoGeneratorFileSystem(db_config) + backend = BackendFileSystem{config: &config} } else { // TODO: Handle other sources log.Panic("Unsupported source type at the moment") @@ -185,26 +110,20 @@ func GetClient(config DatagoConfig) *DatagoClient { // Create the client client := &DatagoClient{ concurrency: config.ConcurrentDownloads, - chanPageResults: make(chan dbResponse, 2), - chanSampleMetadata: make(chan dbSampleMetadata, config.PrefetchBufferSize), + chanPages: make(chan Pages, 2), + chanSampleMetadata: make(chan SampleDataPointers, config.PrefetchBufferSize), chanSamples: make(chan Sample, config.SamplesBufferSize), - require_images: config.RequireImages, - require_embeddings: config.RequireEmbeddings, - has_masks: strings.Split(config.HasMasks, ","), - has_latents: strings.Split(config.HasLatents, ","), + crop_and_resize: config.CropAndResize, default_image_size: config.DefaultImageSize, downsampling_ratio: config.DownsamplingRatio, min_aspect_ratio: config.MinAspectRatio, max_aspect_ratio: config.MaxAspectRatio, pre_encode_images: config.PreEncodeImages, - sources: config.Sources, - rank: config.Rank, - world_size: config.WorldSize, context: nil, cancel: nil, waitGroup: nil, - frontend: frontend, + generator: generator, backend: backend, } @@ -260,7 +179,7 @@ func (c *DatagoClient) Start() { wg.Add(1) go func() { defer wg.Done() - c.frontend.collectPages(c.context, 
c.chanPageResults) // Collect the root data source pages + c.generator.generatePages(c.context, c.chanPages) // Collect the root data source pages }() wg.Add(1) @@ -272,7 +191,7 @@ func (c *DatagoClient) Start() { wg.Add(1) go func() { defer wg.Done() - c.backend.collectSamples(c.chanSampleMetadata, c.chanSamples, arAwareTransform) // Fetch the payloads and and deserialize them + c.backend.collectSamples(c.chanSampleMetadata, c.chanSamples, arAwareTransform, c.pre_encode_images) // Fetch the payloads and deserialize them }() c.waitGroup = &wg @@ -289,7 +208,7 @@ func (c *DatagoClient) GetSample() Sample { if sample, ok := <-c.chanSamples; ok { return sample } - + fmt.Println("chanSamples closed, no more samples to serve") return Sample{} } @@ -304,7 +223,7 @@ func (c *DatagoClient) Stop() { c.cancel() // Clear the channels, in case a commit is blocking - go consumeChannel(c.chanPageResults) + go consumeChannel(c.chanPages) go consumeChannel(c.chanSampleMetadata) go consumeChannel(c.chanSamples) @@ -330,21 +249,14 @@ func (c *DatagoClient) asyncDispatch() { fmt.Println("Metadata fetch goroutine wrapping up") close(c.chanSampleMetadata) return - case page, open := <-c.chanPageResults: + case page, open := <-c.chanPages: if !open { fmt.Println("No more metadata to fetch, wrapping up") close(c.chanSampleMetadata) return } - for _, item := range page.DBSampleMetadata { - // Skip the sample if multi-rank is enabled and the rank is not the one we're interested in - // NOTE: if the front end is a DB, this is a wasteful way to distribute the work - // since we waste most of the page we fetched - if c.world_size > 1 && computeFNVHash32(item.Id)%c.world_size != c.rank { - continue - } - + for _, item := range page.samplesDataPointers { select { case <-c.context.Done(): fmt.Println("Metadata fetch goroutine wrapping up") diff --git a/src/pkg/client/utils.go b/src/pkg/client/utils.go deleted file mode 100644 index 98a5d47..0000000 --- a/src/pkg/client/utils.go +++ /dev/null
@@ -1,57 +0,0 @@ -package datago - -import ( - "hash/fnv" - "time" - "unsafe" -) - -func computeFNVHash32(input string) uint32 { - // Create a new FNV-1a 32-bit hash - hasher := fnv.New32a() - - // Write data to the hash - hasher.Write([]byte(input)) - - // Compute the hash and return it as an integer - return hasher.Sum32() -} - -func dataPtrFromSlice(a []uint8) uintptr { - if len(a) == 0 { - return 0 - } - return uintptr(unsafe.Pointer(&a[0])) -} - -func exponentialBackoffWait(retries int) { - baseDelay := time.Second - maxDelay := 64 * time.Second - - // Calculate the delay with exponential backoff - delay := baseDelay * (1 << uint(retries)) - if delay > maxDelay { - delay = maxDelay - } - time.Sleep(delay) -} - -func getLast5Chars(s string) string { - runes := []rune(s) - if len(runes) <= 5 { - return s - } - return string(runes[len(runes)-5:]) -} - -func sanitizeStr(optional_str *string) string { - if optional_str == nil { - return "" - } - return *optional_str -} - -func consumeChannel[T any](ch <-chan T) { - for range ch { - } -} diff --git a/src/pkg/client/frontend_db.go b/src/pkg/generator_db.go similarity index 61% rename from src/pkg/client/frontend_db.go rename to src/pkg/generator_db.go index 00df77c..3d38560 100644 --- a/src/pkg/client/frontend_db.go +++ b/src/pkg/generator_db.go @@ -8,13 +8,19 @@ import ( "log" "net/http" "os" + "strings" "time" ) // Interact with a DB to get payloads and process them -// Define a frontend and a backend goroutine +// Define a generator and a backend goroutine // --- DB Communication structures --------------------------------------------------------------------------------------------------------------------------------------------------------------- +type urlPayload struct { + url string + content []byte +} + type urlLatent struct { URL string `json:"file_direct_url"` LatentType string `json:"latent_type"` @@ -55,20 +61,40 @@ type dbRequest struct { lacksLatents string } -func (c *DatagoConfig) getDbRequest() 
dbRequest { +// -- Define the front end goroutine --------------------------------------------------------------------------------------------------------------------------------------------------------------- +type GeneratorDBConfig struct { + // Request parameters + Sources string + RequireImages bool + RequireEmbeddings bool + Tags []string + TagsNE []string + HasAttributes []string + LacksAttributes []string + HasMasks []string + LacksMasks []string + HasLatents []string + LacksLatents []string + ConcurrentDownloads int + PageSize int + Rank uint32 + WorldSize uint32 +} + +func (c *GeneratorDBConfig) getDbRequest() dbRequest { fields := "attributes,image_direct_url" - if c.HasLatents != "" || c.HasMasks != "" { + if len(c.HasLatents) > 0 || len(c.HasMasks) > 0 { fields += ",latents" fmt.Println("Including some latents:", c.HasLatents, c.HasMasks) } - if c.Tags != "" { + if len(c.Tags) > 0 { fields += ",tags" fmt.Println("Including some tags:", c.Tags) } - if c.HasLatents != "" { + if len(c.HasLatents) > 0 { fmt.Println("Including some attributes:", c.HasLatents) } @@ -83,26 +109,45 @@ func (c *DatagoConfig) getDbRequest() dbRequest { return dbRequest{ fields: fields, - sources: sanitizeStr(&c.Sources), + sources: c.Sources, pageSize: fmt.Sprintf("%d", c.PageSize), - tags: sanitizeStr(&c.Tags), - tagsNE: sanitizeStr(&c.TagsNE), - hasAttributes: sanitizeStr(&c.HasAttributes), - lacksAttributes: sanitizeStr(&c.LacksAttributes), - hasMasks: sanitizeStr(&c.HasMasks), - lacksMasks: sanitizeStr(&c.LacksMasks), - hasLatents: sanitizeStr(&c.HasLatents), - lacksLatents: sanitizeStr(&c.LacksLatents), + tags: strings.Join(c.Tags, ","), + tagsNE: strings.Join(c.TagsNE, ","), + hasAttributes: strings.Join(c.HasAttributes, ","), + lacksAttributes: strings.Join(c.LacksAttributes, ","), + hasMasks: strings.Join(c.HasMasks, ","), + lacksMasks: strings.Join(c.LacksMasks, ","), + hasLatents: strings.Join(c.HasLatents, ","), + lacksLatents: strings.Join(c.LacksLatents, ","), } } 
-// -- Define the front end goroutine --------------------------------------------------------------------------------------------------------------------------------------------------------------- -type datagoFrontendDB struct { +func GetDefaultDBConfig() GeneratorDBConfig { + return GeneratorDBConfig{ + Sources: "", + RequireImages: true, + RequireEmbeddings: false, + Tags: []string{}, + TagsNE: []string{}, + HasAttributes: []string{}, + LacksAttributes: []string{}, + HasMasks: []string{}, + LacksMasks: []string{}, + HasLatents: []string{}, + LacksLatents: []string{}, + Rank: 0, + WorldSize: 0, + ConcurrentDownloads: 1, + PageSize: 1000, + } +} + +type datagoGeneratorDB struct { baseRequest http.Request + config GeneratorDBConfig } -func newDatagoFrontendDB(config DatagoConfig) datagoFrontendDB { - // Define the base request once and for all +func newDatagoGeneratorDB(config GeneratorDBConfig) datagoGeneratorDB { request := config.getDbRequest() api_key := os.Getenv("DATAROOM_API_KEY") @@ -118,10 +163,22 @@ func newDatagoFrontendDB(config DatagoConfig) datagoFrontendDB { fmt.Println("Dataroom API URL:", api_url) fmt.Println("Dataroom API KEY last characters:", getLast5Chars(api_key)) - return datagoFrontendDB{baseRequest: *getHTTPRequest(api_url, api_key, request)} + generatorDBConfig := GeneratorDBConfig{ + RequireImages: config.RequireImages, + RequireEmbeddings: config.RequireEmbeddings, + HasMasks: config.HasMasks, + LacksMasks: config.LacksMasks, + HasLatents: config.HasLatents, + LacksLatents: config.LacksLatents, + Sources: config.Sources, + Rank: config.Rank, + WorldSize: config.WorldSize, + } + + return datagoGeneratorDB{baseRequest: *getHTTPRequest(api_url, api_key, request), config: generatorDBConfig} } -func (f datagoFrontendDB) collectPages(ctx context.Context, chanPageResults chan dbResponse) { +func (f datagoGeneratorDB) generatePages(ctx context.Context, chanPages chan Pages) { // Fetch pages from the API, and feed the results to the items channel 
// This is meant to be run in a goroutine http_client := http.Client{Timeout: 30 * time.Second} @@ -152,6 +209,7 @@ func (f datagoFrontendDB) collectPages(ctx context.Context, chanPageResults chan if err = json.Unmarshal(body, &data); err != nil { return nil, err } + return &data, nil } @@ -176,13 +234,20 @@ func (f datagoFrontendDB) collectPages(ctx context.Context, chanPageResults chan // Commit the possible results to the downstream goroutines if len(data.DBSampleMetadata) > 0 { - chanPageResults <- *data + // TODO: There's probably a better way to do this + samplesDataPointers := make([]SampleDataPointers, len(data.DBSampleMetadata)) + + for i, sample := range data.DBSampleMetadata { + samplesDataPointers[i] = sample + } + + chanPages <- Pages{samplesDataPointers} } // Check if there are more pages to fetch if data.Next == "" { fmt.Println("No more pages to fetch, wrapping up") - close(chanPageResults) + close(chanPages) return } @@ -200,7 +265,7 @@ func (f datagoFrontendDB) collectPages(ctx context.Context, chanPageResults chan // Check if we consumed all the retries if !valid_page { fmt.Println("Too many errors fetching new pages, wrapping up") - close(chanPageResults) + close(chanPages) return } } diff --git a/src/pkg/generator_filesystem.go b/src/pkg/generator_filesystem.go new file mode 100644 index 0000000..e108093 --- /dev/null +++ b/src/pkg/generator_filesystem.go @@ -0,0 +1,76 @@ +package datago + +import ( + "context" + "fmt" + "os" + "path/filepath" +) + +// Walk over a local directory and return the list of files +// Note that we'll page this, so that file loading can start before the full list is available + +// --- File system walk structures --------------------------------------------------------------------------------------------------------------------------------------------------------------- +type fsSampleMetadata struct { + filePath string + fileName string +} + +// -- Define the front end goroutine 
--------------------------------------------------------------------------------------------------------------------------------------------------------------- +type datagoGeneratorFileSystem struct { + root_directory string + extensions set + page_size int +} + +type GeneratorFileSystemConfig struct { + RootPath string + PageSize int +} + +func newDatagoGeneratorFileSystem(config GeneratorFileSystemConfig) datagoGeneratorFileSystem { + supported_img_extensions := []string{".jpg", ".jpeg", ".png", ".JPEG", ".JPG", ".PNG"} + var extensionsMap = make(set) + for _, ext := range supported_img_extensions { + extensionsMap.Add(ext) + } + fmt.Println("File system root directory", config.RootPath) + fmt.Println("Supported image extensions", supported_img_extensions) + + return datagoGeneratorFileSystem{root_directory: config.RootPath, extensions: extensionsMap, page_size: config.PageSize} +} + +func (f datagoGeneratorFileSystem) generatePages(ctx context.Context, chanPages chan Pages) { + // Walk over the directory and feed the results to the items channel + // This is meant to be run in a goroutine + + var samples []SampleDataPointers + + err := filepath.Walk(f.root_directory, func(path string, info os.FileInfo, err error) error { + if err != nil { + return err + } + if !info.IsDir() && f.extensions.Contains(filepath.Ext(path)) { + new_sample := fsSampleMetadata{filePath: path, fileName: info.Name()} + samples = append(samples, SampleDataPointers(new_sample)) + } + + // Check if we have enough files to send a page + if len(samples) >= f.page_size { + chanPages <- Pages{samples} + samples = nil + } + return nil + }) + + if err != nil { + fmt.Println("Error walking the path", f.root_directory) + } else { + // Send the last page + if len(samples) > 0 { + chanPages <- Pages{samples} + } + } + + close(chanPages) +} diff --git a/src/pkg/client/serdes.go b/src/pkg/serdes.go similarity index 93% rename from src/pkg/client/serdes.go rename to src/pkg/serdes.go index 
5c22ce1..62288da 100644 --- a/src/pkg/client/serdes.go +++ b/src/pkg/serdes.go @@ -123,7 +123,7 @@ func imageFromBuffer(buffer []byte, transform *ARAwareTransform, aspect_ratio fl return &img_payload, aspect_ratio, nil } -func fetchURL(client *http.Client, url string, retries int) (URLPayload, error) { +func fetchURL(client *http.Client, url string, retries int) (urlPayload, error) { // Helper to fetch a binary payload from a URL err_msg := "" @@ -144,10 +144,10 @@ func fetchURL(client *http.Client, url string, retries int) (URLPayload, error) continue } - return URLPayload{url: url, content: body_bytes}, nil + return urlPayload{url: url, content: body_bytes}, nil } - return URLPayload{url: url, content: nil}, fmt.Errorf(err_msg) + return urlPayload{url: url, content: nil}, fmt.Errorf(err_msg) } func fetchImage(client *http.Client, url string, retries int, transform *ARAwareTransform, aspect_ratio float64, pre_encode_image bool, is_mask bool) (*ImagePayload, float64, error) { @@ -181,7 +181,7 @@ func fetchImage(client *http.Client, url string, retries int, transform *ARAware return nil, -1., err_report } -func fetchSample(config *DatagoConfig, http_client *http.Client, sample_result dbSampleMetadata, transform *ARAwareTransform) *Sample { +func fetchSample(config *GeneratorDBConfig, http_client *http.Client, sample_result dbSampleMetadata, transform *ARAwareTransform, pre_encode_image bool) *Sample { // Per sample work: // - fetch the raw payloads // - deserialize / decode, depending on the types @@ -194,7 +194,7 @@ func fetchSample(config *DatagoConfig, http_client *http.Client, sample_result d // Base image if config.RequireImages { - base_image, new_aspect_ratio, err := fetchImage(http_client, sample_result.ImageDirectURL, retries, transform, aspect_ratio, config.PreEncodeImages, false) + base_image, new_aspect_ratio, err := fetchImage(http_client, sample_result.ImageDirectURL, retries, transform, aspect_ratio, pre_encode_image, false) if err != nil { 
fmt.Println("Error fetching image:", sample_result.Id) @@ -213,7 +213,7 @@ func fetchSample(config *DatagoConfig, http_client *http.Client, sample_result d for _, latent := range sample_result.Latents { if strings.Contains(latent.LatentType, "image") && !strings.Contains(latent.LatentType, "latent_") { // Image types, registered as latents but they need to be jpg-decoded - new_image, _, err := fetchImage(http_client, latent.URL, retries, transform, aspect_ratio, config.PreEncodeImages, false) + new_image, _, err := fetchImage(http_client, latent.URL, retries, transform, aspect_ratio, pre_encode_image, false) if err != nil { fmt.Println("Error fetching masked image:", sample_result.Id, latent.LatentType) return nil @@ -222,7 +222,7 @@ func fetchSample(config *DatagoConfig, http_client *http.Client, sample_result d additional_images[latent.LatentType] = *new_image } else if latent.IsMask { // Mask types, registered as latents but they need to be png-decoded - mask_ptr, _, err := fetchImage(http_client, latent.URL, retries, transform, aspect_ratio, config.PreEncodeImages, true) + mask_ptr, _, err := fetchImage(http_client, latent.URL, retries, transform, aspect_ratio, pre_encode_image, true) if err != nil { fmt.Println("Error fetching mask:", sample_result.Id, latent.LatentType) return nil diff --git a/src/pkg/set.go b/src/pkg/set.go new file mode 100644 index 0000000..c27f9e2 --- /dev/null +++ b/src/pkg/set.go @@ -0,0 +1,25 @@ +package datago + +// Define a set type using a map with empty struct values +type set map[string]struct{} + +// Add an element to the set +func (s set) Add(value string) { + s[value] = struct{}{} +} + +// Remove an element from the set +func (s set) Remove(value string) { + delete(s, value) +} + +// Check if the set contains an element +func (s set) Contains(value string) bool { + _, exists := s[value] + return exists +} + +// Get the size of the set +func (s set) Size() int { + return len(s) +} diff --git a/src/pkg/client/transforms.go 
b/src/pkg/transforms.go similarity index 100% rename from src/pkg/client/transforms.go rename to src/pkg/transforms.go diff --git a/src/pkg/utils.go b/src/pkg/utils.go new file mode 100644 index 0000000..2e32f7f --- /dev/null +++ b/src/pkg/utils.go @@ -0,0 +1,25 @@ +package datago + +import ( + "unsafe" +) + +func dataPtrFromSlice(a []uint8) uintptr { + if len(a) == 0 { + return 0 + } + return uintptr(unsafe.Pointer(&a[0])) +} + +func getLast5Chars(s string) string { + runes := []rune(s) + if len(runes) <= 5 { + return s + } + return string(runes[len(runes)-5:]) +} + +func consumeChannel[T any](ch <-chan T) { + for range ch { + } +} diff --git a/src/pkg/worker_filesystem.go b/src/pkg/worker_filesystem.go new file mode 100644 index 0000000..f11bce0 --- /dev/null +++ b/src/pkg/worker_filesystem.go @@ -0,0 +1,67 @@ +package datago + +import ( + "fmt" + "os" +) + +type BackendFileSystem struct { + config *DatagoConfig +} + +func loadSample(config *DatagoConfig, filesystem_sample fsSampleMetadata, transform *ARAwareTransform, _pre_encode_images bool) *Sample { + // Load the file into []bytes + bytes_buffer, err := os.ReadFile(filesystem_sample.filePath) + if err != nil { + fmt.Println("Error reading file:", filesystem_sample.filePath) + return nil + } + + img_payload, _, err := imageFromBuffer(bytes_buffer, transform, -1., config.PreEncodeImages, false) + if err != nil { + fmt.Println("Error loading image:", filesystem_sample.fileName) + return nil + } + + return &Sample{ID: filesystem_sample.fileName, + Image: *img_payload, + } +} + +func (b BackendFileSystem) collectSamples(chanSampleMetadata chan SampleDataPointers, chanSamples chan Sample, transform *ARAwareTransform, pre_encode_images bool) { + + ack_channel := make(chan bool) + + sampleWorker := func() { + for { + item_to_fetch, open := <-chanSampleMetadata + if !open { + ack_channel <- true + return + } + + // Cast the item to fetch to the correct type + filesystem_sample, ok := item_to_fetch.(fsSampleMetadata) + 
if !ok { + panic("Failed to cast the item to fetch to fsSampleMetadata. This worker is probably misconfigured") + } + + sample := loadSample(b.config, filesystem_sample, transform, pre_encode_images) + if sample != nil { + chanSamples <- *sample + } + } + } + + // Start the workers and work on the metadata channel + for i := 0; i < b.config.ConcurrentDownloads; i++ { + go sampleWorker() + } + + // Wait for all the workers to be done or overall context to be cancelled + for i := 0; i < b.config.ConcurrentDownloads; i++ { + <-ack_channel + } + close(chanSamples) + fmt.Println("No more items to serve, wrapping up") +} diff --git a/src/pkg/client/backend_http.go b/src/pkg/worker_http.go similarity index 61% rename from src/pkg/client/backend_http.go rename to src/pkg/worker_http.go index ca808eb..cfec108 100644 --- a/src/pkg/client/backend_http.go +++ b/src/pkg/worker_http.go @@ -7,10 +7,10 @@ import ( ) type BackendHTTP struct { - config *DatagoConfig + config *GeneratorDBConfig } -func (b BackendHTTP) collectSamples(chanSampleMetadata chan dbSampleMetadata, chanSamples chan Sample, transform *ARAwareTransform) { +func (b BackendHTTP) collectSamples(chanSampleMetadata chan SampleDataPointers, chanSamples chan Sample, transform *ARAwareTransform, pre_encode_images bool) { ack_channel := make(chan bool) @@ -25,7 +25,13 @@ func (b BackendHTTP) collectSamples(chanSampleMetadata chan dbSampleMetadata, ch return } - sample := fetchSample(b.config, &http_client, item_to_fetch, transform) + // Cast the item to fetch to the correct type + http_sample, ok := item_to_fetch.(dbSampleMetadata) + if !ok { + panic("Failed to cast the item to fetch to dbSampleMetadata. 
This worker is probably misconfigured") + } + + sample := fetchSample(b.config, &http_client, http_sample, transform, pre_encode_images) if sample != nil { chanSamples <- *sample } diff --git a/src/tests/client_test.go b/src/tests/client_test.go index 7774cbf..30acf96 100644 --- a/src/tests/client_test.go +++ b/src/tests/client_test.go @@ -4,7 +4,7 @@ import ( "os" "testing" - datago "datago/pkg/client" + datago "datago/pkg" "github.com/davidbyttow/govips/v2/vips" ) @@ -13,10 +13,17 @@ func get_test_source() string { return os.Getenv("DATAROOM_TEST_SOURCE") } -func TestClientStartStop(t *testing.T) { +func get_default_test_config() datago.DatagoConfig { config := datago.GetDefaultConfig() - config.Sources = get_test_source() - config.PageSize = 32 + db_config := datago.GetDefaultDBConfig() + db_config.Sources = get_test_source() + db_config.PageSize = 32 + config.SourceConfig = db_config + return config +} + +func TestClientStartStop(t *testing.T) { + config := get_default_test_config() // Check that we can start, do nothing and stop the client immediately client := datago.GetClient(config) @@ -25,9 +32,7 @@ func TestClientStartStop(t *testing.T) { } func TestClientNoStart(t *testing.T) { - config := datago.GetDefaultConfig() - config.Sources = get_test_source() - config.PageSize = 32 + config := get_default_test_config() // Check that we can get a sample without starting the client client := datago.GetClient(config) @@ -40,9 +45,7 @@ func TestClientNoStart(t *testing.T) { func TestClientNoStop(t *testing.T) { // Check that we can start, get a sample, and destroy the client immediately // In that case Stop() should be called in the background, and everything should work just fine - config := datago.GetDefaultConfig() - config.Sources = get_test_source() - config.PageSize = 32 + config := get_default_test_config() config.SamplesBufferSize = 1 client := datago.GetClient(config) @@ -54,9 +57,7 @@ func TestClientNoStop(t *testing.T) { func TestMoreThanBufferSize(t 
*testing.T) { // Check that we can start, get a sample, and destroy the client immediately // In that case Stop() should be called in the background, and everything should work just fine - config := datago.GetDefaultConfig() - config.Sources = get_test_source() - config.PageSize = 32 + config := get_default_test_config() config.SamplesBufferSize = 1 client := datago.GetClient(config) @@ -69,11 +70,11 @@ func TestMoreThanBufferSize(t *testing.T) { } func TestFetchImage(t *testing.T) { - config := datago.GetDefaultConfig() - config.Sources = get_test_source() - config.RequireImages = true - config.PageSize = 32 + config := get_default_test_config() config.SamplesBufferSize = 1 + db_config := config.SourceConfig.(datago.GeneratorDBConfig) + db_config.RequireImages = true + config.SourceConfig = db_config // Check that we can get an image client := datago.GetClient(config) @@ -98,13 +99,14 @@ func TestFetchImage(t *testing.T) { } func TestExtraFields(t *testing.T) { - config := datago.GetDefaultConfig() - config.Sources = get_test_source() - config.RequireImages = true - config.PageSize = 32 - config.HasLatents = "masked_image" - config.HasMasks = "segmentation_mask" + config := get_default_test_config() config.SamplesBufferSize = 1 + db_config := config.SourceConfig.(datago.GeneratorDBConfig) + db_config.RequireImages = true + db_config.PageSize = 32 + db_config.HasLatents = []string{"masked_image"} + db_config.HasMasks = []string{"segmentation_mask"} + config.SourceConfig = db_config // Check that we can get an image client := datago.GetClient(config) @@ -129,11 +131,13 @@ func TestExtraFields(t *testing.T) { } func TestCropAndResize(t *testing.T) { - config := datago.GetDefaultConfig() - config.Sources = get_test_source() - config.RequireImages = true - config.PageSize = 32 + config := get_default_test_config() + config.SamplesBufferSize = 1 config.CropAndResize = true + db_config := config.SourceConfig.(datago.GeneratorDBConfig) + db_config.RequireImages = true + 
db_config.PageSize = 32 + config.SourceConfig = db_config client := datago.GetClient(config) client.Start() @@ -169,14 +173,18 @@ func TestCropAndResize(t *testing.T) { func TestImageBufferCompression(t *testing.T) { // Check that the image buffer is compressed, and that we can decode it properly - config := datago.GetDefaultConfig() - config.Sources = get_test_source() - config.RequireImages = true - config.PageSize = 32 + config := get_default_test_config() + config.SamplesBufferSize = 1 config.CropAndResize = true - config.HasLatents = "masked_image" - config.HasMasks = "segmentation_mask" config.PreEncodeImages = true + + db_config := config.SourceConfig.(datago.GeneratorDBConfig) + db_config.RequireImages = true + db_config.PageSize = 32 + db_config.HasLatents = []string{"masked_image"} + db_config.HasMasks = []string{"segmentation_mask"} + config.SourceConfig = db_config + client := datago.GetClient(config) sample := client.GetSample()