Skip to content

Commit

Permalink
Do not stat shared resources when downloading (#1038)
Browse files Browse the repository at this point in the history
  • Loading branch information
ishank011 authored Aug 3, 2020
1 parent 29fbcb8 commit 83c5528
Show file tree
Hide file tree
Showing 5 changed files with 56 additions and 32 deletions.
9 changes: 9 additions & 0 deletions changelog/unreleased/download-shares-fix.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
Bugfix: Do not stat shared resources when downloading

Previously, we statted the resources in all download requests resulting in
failures when downloading references. This PR fixes that by statting only in
case the resource is not present in the shares folder. It also fixes a bug where
we allowed uploading to the mount path, resulting in overwriting the user home
directory.

https://github.com/cs3org/reva/pull/1038
7 changes: 7 additions & 0 deletions examples/storage-references/storage-reva.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,10 @@ address = "0.0.0.0:18000"
driver = "local"
mount_path = "/reva"
mount_id = "123e4567-e89b-12d3-a456-426655440000"
data_server_url = "http://localhost:18001/data"

[http]
address = "0.0.0.0:18001"

[http.services.dataprovider]
driver = "local"
40 changes: 20 additions & 20 deletions internal/grpc/services/gateway/storageprovider.go
Original file line number Diff line number Diff line change
Expand Up @@ -104,25 +104,6 @@ func (s *svc) getHome(ctx context.Context) string {
return "/home"
}
func (s *svc) InitiateFileDownload(ctx context.Context, req *provider.InitiateFileDownloadRequest) (*gateway.InitiateFileDownloadResponse, error) {
statReq := &provider.StatRequest{Ref: req.Ref}
statRes, err := s.stat(ctx, statReq)
if err != nil {
return &gateway.InitiateFileDownloadResponse{
Status: status.NewInternal(ctx, err, "gateway: error stating ref:"+req.Ref.String()),
}, nil
}
if statRes.Status.Code != rpc.Code_CODE_OK {
if statRes.Status.Code == rpc.Code_CODE_NOT_FOUND {
return &gateway.InitiateFileDownloadResponse{
Status: status.NewNotFound(ctx, "gateway: file not found"),
}, nil
}
err := status.NewErrorFromCode(statRes.Status.Code, "gateway")
return &gateway.InitiateFileDownloadResponse{
Status: status.NewInternal(ctx, err, "gateway: error stating ref"),
}, nil
}

p, err := s.getPath(ctx, req.Ref)
if err != nil {
return &gateway.InitiateFileDownloadResponse{
Expand All @@ -131,13 +112,31 @@ func (s *svc) InitiateFileDownload(ctx context.Context, req *provider.InitiateFi
}

if !s.inSharedFolder(ctx, p) {
statReq := &provider.StatRequest{Ref: req.Ref}
statRes, err := s.stat(ctx, statReq)
if err != nil {
return &gateway.InitiateFileDownloadResponse{
Status: status.NewInternal(ctx, err, "gateway: error stating ref:"+req.Ref.String()),
}, nil
}
if statRes.Status.Code != rpc.Code_CODE_OK {
if statRes.Status.Code == rpc.Code_CODE_NOT_FOUND {
return &gateway.InitiateFileDownloadResponse{
Status: status.NewNotFound(ctx, "gateway: file not found"),
}, nil
}
err := status.NewErrorFromCode(statRes.Status.Code, "gateway")
return &gateway.InitiateFileDownloadResponse{
Status: status.NewInternal(ctx, err, "gateway: error stating ref"),
}, nil
}
return s.initiateFileDownload(ctx, req)
}

log := appctx.GetLogger(ctx)
if s.isSharedFolder(ctx, p) || s.isShareName(ctx, p) {
log.Debug().Msgf("path:%s points to shared folder or share name", p)
err := errtypes.PermissionDenied("gateway: cannot upload to share folder or share name: path=" + p)
err := errtypes.PermissionDenied("gateway: cannot download share folder or share name: path=" + p)
log.Err(err).Msg("gateway: error downloading")
return &gateway.InitiateFileDownloadResponse{
Status: status.NewInvalidArg(ctx, "path points to share folder or share name"),
Expand Down Expand Up @@ -194,6 +193,7 @@ func (s *svc) InitiateFileDownload(ctx context.Context, req *provider.InitiateFi
},
}
req.Ref = ref
log.Debug().Msg("download path: " + target)
return s.initiateFileDownload(ctx, req)
}

Expand Down
5 changes: 5 additions & 0 deletions internal/grpc/services/storageprovider/storageprovider.go
Original file line number Diff line number Diff line change
Expand Up @@ -273,6 +273,11 @@ func (s *service) InitiateFileUpload(ctx context.Context, req *provider.Initiate
Status: status.NewInternal(ctx, err, "error unwrapping path"),
}, nil
}
if newRef.GetPath() == "/" {
return &provider.InitiateFileUploadResponse{
Status: status.NewInternal(ctx, errors.New("can't upload to mount path"), ""),
}, nil
}
url := *s.dataServerURL
if s.conf.DisableTus {
url.Path = path.Join("/", url.Path, newRef.GetPath())
Expand Down
27 changes: 15 additions & 12 deletions pkg/user/manager/rest/rest.go
Original file line number Diff line number Diff line change
Expand Up @@ -353,26 +353,24 @@ func (m *manager) GetUserByClaim(ctx context.Context, claim, value string) (*use

}

func (m *manager) findUsersByFilter(ctx context.Context, url string) ([]*userpb.User, error) {
func (m *manager) findUsersByFilter(ctx context.Context, url string, users map[string]*userpb.User) error {

userData, err := m.sendAPIRequest(ctx, url)
if err != nil {
return nil, err
return err
}

users := []*userpb.User{}

for _, usr := range userData {
usrInfo, ok := usr.(map[string]interface{})
if !ok {
return nil, errors.New("rest: error in type assertion")
return errors.New("rest: error in type assertion")
}

uid := &userpb.UserId{
OpaqueId: usrInfo["upn"].(string),
Idp: m.conf.IDProvider,
}
users = append(users, &userpb.User{
users[uid.OpaqueId] = &userpb.User{
Id: uid,
Username: usrInfo["upn"].(string),
Mail: usrInfo["primaryAccountEmail"].(string),
Expand All @@ -389,10 +387,10 @@ func (m *manager) findUsersByFilter(ctx context.Context, url string) ([]*userpb.
},
},
},
})
}
}

return users, nil
return nil
}

func (m *manager) FindUsers(ctx context.Context, query string) ([]*userpb.User, error) {
Expand All @@ -407,18 +405,23 @@ func (m *manager) FindUsers(ctx context.Context, query string) ([]*userpb.User,
return nil, errors.New("rest: illegal characters present in query")
}

users := []*userpb.User{}
users := make(map[string]*userpb.User)

for _, f := range filters {
url := fmt.Sprintf("%s/Identity/?filter=%s:contains:%s&field=id&field=upn&field=primaryAccountEmail&field=displayName&field=uid&field=gid",
m.conf.APIBaseURL, f, query)
filteredUsers, err := m.findUsersByFilter(ctx, url)
err := m.findUsersByFilter(ctx, url, users)
if err != nil {
return nil, err
}
users = append(users, filteredUsers...)
}
return users, nil

userSlice := make([]*userpb.User, len(users))
for _, v := range users {
userSlice = append(userSlice, v)
}

return userSlice, nil
}

func (m *manager) GetUserGroups(ctx context.Context, uid *userpb.UserId) ([]string, error) {
Expand Down

0 comments on commit 83c5528

Please sign in to comment.