Skip to content

Commit

Permalink
WIP, would need more work but short on time
Browse files Browse the repository at this point in the history
[x] initial refactor
[x] adding a barebones filesystem dataloader
[x] barebones unit test -> broken
[ ] benchmark on IN1k
  • Loading branch information
blefaudeux committed Oct 22, 2024
1 parent dd80544 commit d2ab48c
Show file tree
Hide file tree
Showing 19 changed files with 472 additions and 281 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/go.yml
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ jobs:
sudo apt-get -y install libvips-dev
- name: Build
run: cd src/cmd/main && go build -v main.go
run: cd src/cmd && go build -v main.go

- name: Test
env:
Expand Down
7 changes: 4 additions & 3 deletions .github/workflows/gopy.yml
Original file line number Diff line number Diff line change
Expand Up @@ -38,16 +38,17 @@ jobs:
- name: Build python module
run: |
cd src/pkg/client
cd src/pkg
gopy pkg -author="Photoroom" -email="[email protected]" -name="datago" .
export DESTINATION="../../../build"
export DESTINATION="../../build"
mkdir -p $DESTINATION/datago
mv datago/* $DESTINATION/datago/.
mv setup.py $DESTINATION/.
mv Makefile $DESTINATION/.
mv README.md $DESTINATION/.
rm LICENSE MANIFEST.in
cd ../../../build
ls ../..
cd ../../build
- name: Install python module
run: |
Expand Down
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ NOTE:
- - Either `export PATH=$PATH:~/go/bin` or add it to your .bashrc
- you may need this to make sure that LDD looks at the current folder `export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:.`

then from the /pkg/client folder:
then from the /pkg folder:

```bash
$ gopy pkg -author="Photoroom" -email="[email protected]" -url="" -name="datago" -version="0.0.1" .
Expand Down
4 changes: 2 additions & 2 deletions generate_python_package.sh
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ DESTINATION="../../../python_$python_version"
rm -rf $DESTINATION

# Build the python package via the gopy toolchain
cd src/pkg/client
cd src/pkg
gopy pkg -author="Photoroom" -email="[email protected]" -url="" -name="datago" -version="0.3" .
mkdir -p $DESTINATION/datago
mv datago/* $DESTINATION/datago/.
Expand All @@ -21,6 +21,6 @@ mv README.md $DESTINATION/.
rm LICENSE
rm MANIFEST.in

cd ../../..
cd ../..


24 changes: 23 additions & 1 deletion python_tests/datago_test.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
from datago import datago
import pytest
import os
from PIL import Image


def get_test_source():
    """Return the value of the DATAROOM_TEST_SOURCE environment variable.

    Returns:
        The variable's value as a string, or None when it is not set.
    """
    return os.getenv("DATAROOM_TEST_SOURCE")


def test_get_sample():
def test_get_sample_db():
# Check that we can instantiate a client and get a sample, nothing more
config = datago.GetDefaultConfig()
config.source = get_test_source()
Expand All @@ -17,6 +18,27 @@ def test_get_sample():
assert data.ID != ""


def test_get_sample_filesystem():
    """Check that a filesystem-backed client can be created and serves a sample.

    A single blank 100x100 RGB image is written to a private temporary
    directory, which is then used as the data root. Using a dedicated
    directory (rather than the current working directory) keeps the test
    hermetic: the loader only ever sees the file created here, and the
    directory is removed automatically even if image creation or the
    assertions fail.
    """
    import tempfile

    with tempfile.TemporaryDirectory() as root:
        # Dump a sample image to the filesystem
        img = Image.new("RGB", (100, 100))
        img.save(os.path.join(root, "test.png"))

        # Check that we can instantiate a client and get a sample, nothing more
        # NOTE(review): these attribute names are kept exactly as in the original
        # test; confirm they match the fields exposed by the generated bindings.
        config = datago.GetDefaultConfig()
        config.SourceType = "filesystem"
        config.Sources = root
        config.sample = 1

        client = datago.GetClient(config)
        data = client.GetSample()
        assert data.ID != ""


# TODO: Backport all the image correctness tests

if __name__ == "__main__":
Expand Down
3 changes: 2 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
pytest
pytest
pillow
22 changes: 5 additions & 17 deletions src/cmd/main/main.go → src/cmd/main.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
package main

import (
datago "datago/pkg/client"
datago "datago/pkg"
"flag"
"fmt"
"os"
Expand All @@ -13,6 +13,8 @@ import (
func main() {
// Define flags
client_config := datago.GetDefaultConfig()
client_config.SourceType = datago.SourceTypeFileSystem
client_config.SourceConfig = datago.GeneratorFileSystemConfig{RootPath: os.Getenv("DATAROOM_TEST_FILESYSTEM"), PageSize: 10}
client_config.DefaultImageSize = 1024
client_config.DownsamplingRatio = 32

Expand All @@ -21,20 +23,6 @@ func main() {
client_config.PrefetchBufferSize = *flag.Int("item_fetch_buffer", 256, "The number of items to pre-load")
client_config.SamplesBufferSize = *flag.Int("item_ready_buffer", 128, "The number of items ready to be served")

client_config.Sources = *flag.String("source", "GETTY", "The source for the items")
client_config.RequireImages = *flag.Bool("require_images", true, "Whether the items require images")
client_config.RequireEmbeddings = *flag.Bool("require_embeddings", false, "Whether the items require the DB embeddings")

client_config.Tags = *flag.String("tags", "", "The tags to filter for")
client_config.TagsNE = *flag.String("tags__ne", "", "The tags that the samples should not have")
client_config.HasMasks = *flag.String("has_masks", "", "The masks to filter for")
client_config.HasLatents = *flag.String("has_latents", "", "The masks to filter for")
client_config.HasAttributes = *flag.String("has_attributes", "", "The attributes to filter for")

client_config.LacksMasks = *flag.String("lacks_masks", "", "The masks to filter against")
client_config.LacksLatents = *flag.String("lacks_latents", "", "The masks to filter against")
client_config.LacksAttributes = *flag.String("lacks_attributes", "", "The attributes to filter against")

limit := flag.Int("limit", 2000, "The number of items to fetch")
profile := flag.Bool("profile", false, "Whether to profile the code")

Expand Down Expand Up @@ -71,7 +59,7 @@ func main() {
for i := 0; i < *limit; i++ {
sample := dataroom_client.GetSample()
if sample.ID == "" {
fmt.Printf("Error fetching sample")
fmt.Println("No more samples ", i, " samples served")
break
}
}
Expand All @@ -83,5 +71,5 @@ func main() {
elapsedTime := time.Since(startTime)
fps := float64(*limit) / elapsedTime.Seconds()
fmt.Printf("Total execution time: %.2f \n", elapsedTime.Seconds())
fmt.Printf("Average fetch rate: %.2f fetches per second\n", fps)
fmt.Printf("Average throughput: %.2f samples per second\n", fps)
}
52 changes: 52 additions & 0 deletions src/pkg/architecture.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
package datago

import "context"

// --- Sample data structures - these will be exposed to the Python world ---------------------------------------------------------------------------------------------------------------------------------------------------------------
// LatentPayload carries one latent buffer; Sample stores these in its Latents map,
// keyed by string.
type LatentPayload struct {
// Data holds the latent bytes.
Data []byte
// NOTE(review): Len presumably mirrors the length of Data — confirm where it is filled in.
Len int
// NOTE(review): DataPtr is an integer-typed pointer value, presumably the address of
// Data's backing storage — confirm against the code that populates it.
DataPtr uintptr
}

// ImagePayload bundles the bytes of one image with the geometry needed to interpret them.
// Sample uses it for its main image, its masks and its additional images.
type ImagePayload struct {
	// Data holds the image bytes.
	Data []byte
	// OriginalHeight and OriginalWidth are the original dimensions; they are a good
	// indicator of the image frequency response at the current resolution.
	OriginalHeight, OriginalWidth int
	// Height, Width and Channels describe the current payload and are useful to decode it.
	Height, Width, Channels int
	// NOTE(review): DataPtr is an integer-typed pointer value, presumably the address of
	// Data's backing storage — confirm against the code that populates it.
	DataPtr uintptr
}

// Sample is one item served to the consumer and is part of the structures exposed to
// the Python world. It groups an identifier, free-form attributes and the image, mask,
// latent and embedding payloads that belong together.
type Sample struct {
	// ID and Source identify the sample. The command-line client in this repository
	// treats an empty ID as "no more samples".
	ID, Source string
	// Attributes holds arbitrary values keyed by attribute name.
	Attributes map[string]interface{}
	// Image is the main image payload.
	Image ImagePayload
	// Masks and AdditionalImages hold further image payloads, keyed by name.
	Masks, AdditionalImages map[string]ImagePayload
	// Latents holds latent payloads, keyed by name.
	Latents map[string]LatentPayload
	// CocaEmbedding is an embedding vector attached to the sample.
	CocaEmbedding []float32
	// Tags lists the tags attached to the sample.
	Tags []string
}

// --- Generator and Backend interfaces ---------------------------------------------------------------------------------------------------------------------------------------------------------------

// The generator is responsible for producing pages of metadata, which are handed over
// to the dispatch goroutine. The metadata is then used to fetch the actual payloads.

// SampleDataPointers is the per-sample metadata element carried inside Pages and fed to
// Backend.collectSamples. It is declared as an empty interface, so any value satisfies it.
// NOTE(review): the concrete type is presumably specific to each generator/backend pair —
// confirm in the implementations.
type SampleDataPointers interface{}

// Pages is the unit a Generator works with on its channel: a batch of per-sample
// metadata entries.
type Pages struct {
// samplesDataPointers holds one metadata entry per sample in the page.
samplesDataPointers []SampleDataPointers
}

// Generator is the interface for components that produce pages of sample metadata.
// Its only method is unexported, so it can only be implemented inside this package.
type Generator interface {
// generatePages receives a context and a channel of Pages.
// NOTE(review): presumably implementations send pages on chanPages until ctx is
// cancelled or the source is exhausted — confirm in the implementations.
generatePages(ctx context.Context, chanPages chan Pages)
}

// Backend is responsible for fetching the payloads and deserializing them.
// Its only method is unexported, so it can only be implemented inside this package.
type Backend interface {
// collectSamples takes a channel of per-sample metadata, a channel of Sample, an
// ARAwareTransform and a boolean flag.
// NOTE(review): presumably metadata is read from chanSampleMetadata and finished
// samples are written to chanSamples, with pre_encode_images controlling whether
// image payloads are encoded before being returned — confirm in the implementations.
collectSamples(chanSampleMetadata chan SampleDataPointers, chanSamples chan Sample, transform *ARAwareTransform, pre_encode_images bool)
}
Loading

0 comments on commit d2ab48c

Please sign in to comment.