diff --git a/auth/provider/provider.go b/auth/provider/provider.go index 607e6ded..f9230d27 100644 --- a/auth/provider/provider.go +++ b/auth/provider/provider.go @@ -76,7 +76,7 @@ func getKey(config JWTConfig) (interface{}, error) { publicKeyFile := config.PubKeyFile switch { case publicKeyFile != "": - kd, err := config.fs.ReadFile(publicKeyFile) + kd, err := config.fs.Get(publicKeyFile) if err != nil { return nil, err } diff --git a/conf/config.go b/conf/config.go index 885e0cc8..dbe3da49 100644 --- a/conf/config.go +++ b/conf/config.go @@ -51,7 +51,7 @@ func NewConfigWithFS(fs core.FS, configFile string) (c *core.Config, err error) func readConfig(fs core.FS, configFile string, v interface{}) (err error) { format := filepath.Ext(configFile) - b, err := fs.ReadFile(configFile) + b, err := fs.Get(configFile) if err != nil { return fmt.Errorf("error reading config: %w", err) } diff --git a/core/core.go b/core/core.go index 5671ceb8..5145e32b 100644 --- a/core/core.go +++ b/core/core.go @@ -86,7 +86,7 @@ func (gj *graphjin) initDiscover() (err error) { func (gj *graphjin) _initDiscover() (err error) { if gj.prod && gj.conf.EnableSchema { - b, err := gj.fs.ReadFile("db.graphql") + b, err := gj.fs.Get("db.graphql") if err != nil { return err } @@ -120,7 +120,7 @@ func (gj *graphjin) _initDiscover() (err error) { if err := writeSchema(gj.dbinfo, &buf); err != nil { return err } - err = gj.fs.CreateFile("db.graphql", buf.Bytes()) + err = gj.fs.Put("db.graphql", buf.Bytes()) if err != nil { return } @@ -178,7 +178,7 @@ func (gj *graphjin) _initSchema() (err error) { if err != nil { return } - err = gj.fs.CreateFile("intro.json", []byte(introJSON)) + err = gj.fs.Put("intro.json", []byte(introJSON)) if err != nil { return } diff --git a/core/internal/allow/allow.go b/core/internal/allow/allow.go index 070fea36..f83432b0 100644 --- a/core/internal/allow/allow.go +++ b/core/internal/allow/allow.go @@ -14,9 +14,8 @@ import ( ) type FS interface { - CreateDir(path string) error - CreateFile(path string, data []byte) error - ReadFile(path string) (data []byte, err error) + Get(path string) (data []byte, err error) + Put(path string, data []byte) (err error) Exists(path string) (exists bool, err error) } @@ -63,11 +62,6 @@ func New(log *_log.Logger, fs FS, readOnly bool) (al *List, err error) { } al.saveChan = make(chan Item) - err = fs.CreateDir(filepath.Join(queryPath, fragmentPath)) - if err != nil { - return - } - go func() { for { v, ok := <-al.saveChan @@ -145,7 +139,7 @@ func (al *List) get(queryPath, name, ext string, useCache bool) (item Item, err jsonFile := filepath.Join(queryPath, (name + ".json")) ok, err := al.fs.Exists(jsonFile) if ok { - vars, err = al.fs.ReadFile(jsonFile) + vars, err = al.fs.Get(jsonFile) } if err != nil { return @@ -204,7 +198,7 @@ func (al *List) saveItem(item Item) (err error) { } ff := filepath.Join(queryPath, "fragments", (fragFile + ".gql")) - err = al.fs.CreateFile(ff, []byte(f.Value)) + err = al.fs.Put(ff, []byte(f.Value)) if err != nil { return } @@ -215,7 +209,7 @@ func (al *List) saveItem(item Item) (err error) { buf.Write(bytes.TrimSpace(item.Query)) qf := filepath.Join(queryPath, (queryFile + ".gql")) - err = al.fs.CreateFile(qf, bytes.TrimSpace(buf.Bytes())) + err = al.fs.Put(qf, bytes.TrimSpace(buf.Bytes())) if err != nil { return } @@ -227,7 +221,7 @@ func (al *List) saveItem(item Item) (err error) { if err != nil { return } - err = al.fs.CreateFile(jf, vars) + err = al.fs.Put(jf, vars) } return } diff --git a/core/internal/allow/gql.go b/core/internal/allow/gql.go index 468d50a1..61dfa978 100644 --- a/core/internal/allow/gql.go +++ b/core/internal/allow/gql.go @@ -29,7 +29,7 @@ func readGQL(fs FS, fname string) (gql []byte, err error) { } func parseGQL(fs FS, fname string, r io.Writer) (err error) { - b, err := fs.ReadFile(fname) + b, err := fs.Get(fname) if err != nil { return err } diff --git a/core/osfs.go b/core/osfs.go index 4605cd1c..05125c0e 100644 --- a/core/osfs.go +++ b/core/osfs.go @@ -13,16 +13,23 @@ type osFS struct { func NewFS(basePath string) *osFS { return &osFS{bp: basePath} } -func (f *osFS) CreateDir(path string) error { - return os.MkdirAll(filepath.Join(f.bp, path), os.ModePerm) +func (f *osFS) Get(path string) ([]byte, error) { + return os.ReadFile(filepath.Join(f.bp, path)) } -func (f *osFS) CreateFile(path string, data []byte) error { - return os.WriteFile(filepath.Join(f.bp, path), data, os.ModePerm) -} +func (f *osFS) Put(path string, data []byte) (err error) { + path = filepath.Join(f.bp, path) -func (f *osFS) ReadFile(path string) ([]byte, error) { - return os.ReadFile(filepath.Join(f.bp, path)) + dir := filepath.Dir(path) + ok, err := f.Exists(dir) + if !ok { + err = os.MkdirAll(dir, os.ModePerm) + } + if err != nil { + return + } + + return os.WriteFile(path, data, os.ModePerm) } func (f *osFS) Exists(path string) (ok bool, err error) { diff --git a/core/plugin.go b/core/plugin.go index 6e008e93..4f34fdbd 100644 --- a/core/plugin.go +++ b/core/plugin.go @@ -1,8 +1,7 @@ package core type FS interface { - CreateDir(path string) error - CreateFile(path string, data []byte) error - ReadFile(path string) (data []byte, err error) + Get(path string) (data []byte, err error) + Put(path string, data []byte) error Exists(path string) (exists bool, err error) } diff --git a/plugin/afero/afero.go b/plugin/afero/afero.go index d273a0f5..c3e979a0 100644 --- a/plugin/afero/afero.go +++ b/plugin/afero/afero.go @@ -2,6 +2,7 @@ package afero import ( "os" + "path/filepath" "github.com/spf13/afero" ) @@ -14,16 +15,21 @@ func NewFS(fs afero.Fs, basePath string) *AferoFS { return &AferoFS{fs: afero.NewBasePathFs(fs, basePath)} } -func (f *AferoFS) CreateDir(path string) error { - return f.fs.MkdirAll(path, os.ModePerm) +func (f *AferoFS) Get(path string) ([]byte, error) { + return afero.ReadFile(f.fs, path) } -func (f *AferoFS) CreateFile(path string, data []byte) error { - return afero.WriteFile(f.fs, path, data, os.ModePerm) -} +func (f *AferoFS) Put(path string, data []byte) (err error) { + dir := filepath.Dir(path) + ok, err := f.Exists(dir) + if !ok { + err = f.fs.MkdirAll(dir, os.ModePerm) + } + if err != nil { + return + } -func (f *AferoFS) ReadFile(path string) ([]byte, error) { - return afero.ReadFile(f.fs, path) + return afero.WriteFile(f.fs, path, data, os.ModePerm) } func (f *AferoFS) Exists(path string) (exists bool, err error) { diff --git a/plugin/osfs/osfs.go b/plugin/osfs/osfs.go index fddc849a..c357db2e 100644 --- a/plugin/osfs/osfs.go +++ b/plugin/osfs/osfs.go @@ -13,16 +13,23 @@ type FS struct { func NewFS(basePath string) *FS { return &FS{bp: basePath} } -func (f *FS) CreateDir(path string) error { - return os.MkdirAll(filepath.Join(f.bp, path), os.ModePerm) +func (f *FS) Get(path string) ([]byte, error) { + return os.ReadFile(filepath.Join(f.bp, path)) } -func (f *FS) CreateFile(path string, data []byte) error { - return os.WriteFile(filepath.Join(f.bp, path), data, os.ModePerm) -} +func (f *FS) Put(path string, data []byte) (err error) { + path = filepath.Join(f.bp, path) -func (f *FS) ReadFile(path string) ([]byte, error) { - return os.ReadFile(filepath.Join(f.bp, path)) + dir := filepath.Dir(path) + ok, err := f.Exists(dir) + if !ok { + err = os.MkdirAll(dir, os.ModePerm) + } + if err != nil { + return + } + + return os.WriteFile(path, data, os.ModePerm) } func (f *FS) Exists(path string) (ok bool, err error) { diff --git a/serv/db.go b/serv/db.go index 5c4c74dd..3f892993 100644 --- a/serv/db.go +++ b/serv/db.go @@ -117,7 +117,7 @@ func initPostgres(conf *Config, openDB, useTelemetry bool, fs core.FS) (*dbConf, if strings.Contains(c.DB.ServerCert, pemSig) { pem = []byte(strings.ReplaceAll(c.DB.ServerCert, `\n`, "\n")) } else { - pem, err = fs.ReadFile(c.DB.ServerCert) + pem, err = fs.Get(c.DB.ServerCert) } if err != nil { @@ -177,11 +177,11 @@ func initMysql(conf *Config, openDB, useTelemetry bool, fs core.FS) (*dbConf, er func loadX509KeyPair(fs core.FS, certFile, keyFile string) ( cert tls.Certificate, err error, ) { - certPEMBlock, err := fs.ReadFile(certFile) + certPEMBlock, err := fs.Get(certFile) if err != nil { return cert, err } - keyPEMBlock, err := fs.ReadFile(keyFile) + keyPEMBlock, err := fs.Get(keyFile) if err != nil { return cert, err } diff --git a/tests/core_test.go b/tests/core_test.go index 61f68890..5f68b896 100644 --- a/tests/core_test.go +++ b/tests/core_test.go @@ -27,8 +27,8 @@ func TestReadInConfigWithEnvVars(t *testing.T) { defer os.RemoveAll(dir) fs := osfs.NewFS(dir) - fs.CreateFile("dev.yml", []byte(devConfig)) - fs.CreateFile("prod.yml", []byte(prodConfig)) + fs.Put("dev.yml", []byte(devConfig)) + fs.Put("prod.yml", []byte(prodConfig)) c, err := conf.NewConfigWithFS(fs, "dev.yml") assert.NoError(t, err) @@ -118,11 +118,7 @@ func TestAllowList(t *testing.T) { defer os.RemoveAll(dir) fs := osfs.NewFS(dir) - if err := fs.CreateDir("queries"); err != nil { - t.Error(err) - return - } - err = fs.CreateFile("queries/getProducts.gql", []byte(gql1)) + err = fs.Put("queries/getProducts.gql", []byte(gql1)) if err != nil { t.Error(err) return diff --git a/tests/insert_test.go b/tests/insert_test.go index 64b6ff16..d6ceb04b 100644 --- a/tests/insert_test.go +++ b/tests/insert_test.go @@ -706,7 +706,6 @@ func TestAllowListWithMutations(t *testing.T) { defer os.RemoveAll(dir) fs := osfs.NewFS(dir) - err = fs.CreateDir("queries") assert.NoError(t, err) conf1 := newConfig(&core.Config{DBType: dbType, DisableAllowList: false}) diff --git a/tests/intro_test.go b/tests/intro_test.go index f6a6666b..288643c1 100644 --- a/tests/intro_test.go +++ b/tests/intro_test.go @@ -55,7 +55,7 @@ func TestIntrospection(t *testing.T) { if err != nil { panic(err) } - b, err := fs.ReadFile("intro.json") + b, err := fs.Get("intro.json") assert.NoError(t, err) assert.NotEmpty(t, b) } diff --git a/wasm/fs.go b/wasm/fs.go index b19d95cf..33ce501d 100644 --- a/wasm/fs.go +++ b/wasm/fs.go @@ -14,22 +14,24 @@ type JSFS struct { bp string } -func NewJSFS(fs js.Value) *JSFS { return &JSFS{fs: fs} } -func NewJSFSWithBase(fs js.Value, path string) *JSFS { return &JSFS{fs: fs, bp: path} } +func NewJSFS(fs js.Value, path string) *JSFS { return &JSFS{fs: fs, bp: path} } -func (f *JSFS) CreateDir(path string) (err error) { +func (f *JSFS) Get(path string) (data []byte, err error) { + path = filepath.Join(f.bp, path) defer func() { if err1 := recover(); err1 != nil { err = toError(err1) } }() - opts := map[string]interface{}{"recursive": true} - path = filepath.Join(f.bp, path) - f.fs.Call("mkdirSync", path, opts) - return nil + buf := f.fs.Call("readFileSync", path) + + a := js.Global().Get("Uint8Array").New(buf) + data = make([]byte, a.Get("length").Int()) + js.CopyBytesToGo(data, a) + return data, nil } -func (f *JSFS) CreateFile(path string, data []byte) (err error) { +func (f *JSFS) Put(path string, data []byte) (err error) { defer func() { if err1 := recover(); err1 != nil { err = toError(err1) @@ -37,6 +39,15 @@ func (f *JSFS) CreateFile(path string, data []byte) (err error) { }() path = filepath.Join(f.bp, path) + dir := filepath.Dir(path) + ok, err := f.Exists(dir) + if !ok { + err = f.createDir(dir) + } + if err != nil { + return + } + a := js.Global().Get("Uint8Array").New(len(data)) js.CopyBytesToJS(a, data) runtime.KeepAlive(data) @@ -49,21 +60,6 @@ func (f *JSFS) CreateFile(path string, data []byte) (err error) { return nil } -func (f *JSFS) ReadFile(path string) (data []byte, err error) { - path = filepath.Join(f.bp, path) - defer func() { - if err1 := recover(); err1 != nil { - err = toError(err1) - } - }() - buf := f.fs.Call("readFileSync", path) - - a := js.Global().Get("Uint8Array").New(buf) - data = make([]byte, a.Get("length").Int()) - js.CopyBytesToGo(data, a) - return data, nil -} - func (f *JSFS) Exists(path string) (exists bool, err error) { defer func() { if err1 := recover(); err1 != nil { @@ -79,15 +75,14 @@ func (f *JSFS) Exists(path string) (exists bool, err error) { return } -type FileInfo struct { - name string - isDir bool -} - -func (fi *FileInfo) Name() string { - return fi.name -} - -func (fi *FileInfo) IsDir() bool { - return fi.isDir +func (f *JSFS) createDir(path string) (err error) { + defer func() { + if err1 := recover(); err1 != nil { + err = toError(err1) + } + }() + opts := map[string]interface{}{"recursive": true} + path = filepath.Join(f.bp, path) + f.fs.Call("mkdirSync", path, opts) + return nil } diff --git a/wasm/main.go b/wasm/main.go index 3a787d93..f121cb81 100644 --- a/wasm/main.go +++ b/wasm/main.go @@ -52,7 +52,7 @@ func graphjinFunc() js.Func { return toJSError(err) } - fs := NewJSFSWithBase(fsv, cpv.String()) + fs := NewJSFS(fsv, cpv.String()) confVal := cov.Get("value").String() confValIsFile := cov.Get("isFile").Bool()