diff --git a/multipart.go b/multipart.go index d7322a7..e5e4ed8 100644 --- a/multipart.go +++ b/multipart.go @@ -58,7 +58,6 @@ func (mp *S3Multipart) AddPart(r io.Reader, size int64, md5sum []byte) error { } req.Header.Set("Content-Length", fmt.Sprintf("%d", size)) - req.Header.Set("Host", req.URL.Host) req.Header.Set("Content-Type", "application/octet-stream") req.ContentLength = size @@ -114,7 +113,6 @@ func (mp *S3Multipart) Complete(contentType string) error { } req.Header.Set("Content-Length", fmt.Sprintf("%d", len(xmlBody))) - req.Header.Set("Host", req.URL.Host) req.Header.Set("Content-Type", contentType) req.ContentLength = int64(len(xmlBody)) diff --git a/s3.go b/s3.go index a32bdb9..f95d24f 100644 --- a/s3.go +++ b/s3.go @@ -2,9 +2,7 @@ package s3 import ( "bytes" - "crypto/hmac" "crypto/md5" - "crypto/sha1" "encoding/base64" "encoding/xml" "fmt" @@ -14,9 +12,12 @@ import ( "net/http" "net/url" "runtime" - "sort" "strings" "time" + + "github.com/aws/aws-sdk-go/aws/credentials" + "github.com/aws/aws-sdk-go/aws/endpoints" + "github.com/aws/aws-sdk-go/aws/signer/v4" ) // S3 provides a wrapper around your S3 credentials. It carries no other internal state @@ -38,66 +39,32 @@ func NewS3(bucket, accessId, secret string) *S3 { } } -func (s3 *S3) signRequest(req *http.Request) { - amzHeaders := "" - resourceUrl, _ := url.Parse("/" + s3.bucket + req.URL.Path) - resource := resourceUrl.String() - - /* Ugh, AWS requires us to order the parameters in a specific ordering for - * signing. Makes sense, but is annoying because a map does not have a defined - * ordering (and basically returns elements in a random order) -- so we have - * to sort by hand */ - query := req.URL.Query() - if len(query) > 0 { - keys := []string{} - - for k := range query { - keys = append(keys, k) - } - - sort.Strings(keys) - - parts := []string{} - - for _, key := range keys { - vals := query[key] - - for _, val := range vals { - if val == "" { - parts = append(parts, url.QueryEscape(key)) - - } else { - part := fmt.Sprintf("%s=%s", url.QueryEscape(key), url.QueryEscape(val)) - parts = append(parts, part) - } - } - } - - req.URL.RawQuery = strings.Join(parts, "&") - } - - if req.URL.RawQuery != "" { - resource += "?" + req.URL.RawQuery +func (s4 *S3) signRequest(req *http.Request) (er error) { + signer := v4.Signer{ + Credentials: credentials.NewStaticCredentials(s4.accessId, s4.secret, ""), + DisableURIPathEscaping: true, } + t := time.Now() if req.Header.Get("Date") == "" { - req.Header.Set("Date", time.Now().Format(time.RFC1123)) + req.Header.Set("Date", t.Format(time.RFC1123)) } - authStr := strings.Join([]string{ - strings.TrimSpace(req.Method), - req.Header.Get("Content-MD5"), - req.Header.Get("Content-Type"), - req.Header.Get("Date"), - amzHeaders + resource, - }, "\n") - - h := hmac.New(sha1.New, []byte(s3.secret)) - h.Write([]byte(authStr)) + var seeker io.ReadSeeker + if req.Body != nil { + seeker = bytes.NewReader(streamToByte(req.Body)) + } + _, err := signer.Sign(req, seeker, "s3", endpoints.ApNortheast1RegionID, t) + if err != nil { + return er + } + return nil +} - h64 := base64.StdEncoding.EncodeToString(h.Sum(nil)) - auth := "AWS" + " " + s3.accessId + ":" + h64 - req.Header.Set("Authorization", auth) +func streamToByte(stream io.Reader) []byte { + buf := new(bytes.Buffer) + buf.ReadFrom(stream) + return buf.Bytes() } func (s3 *S3) resource(path string, values url.Values) string { @@ -177,10 +144,12 @@ func (s3 *S3) Put(r io.Reader, size int64, path string, md5sum []byte, contentTy req.Header.Set("Content-Type", contentType) req.Header.Set("Content-Length", fmt.Sprintf("%d", size)) - req.Header.Set("Host", req.URL.Host) req.ContentLength = size - s3.signRequest(req) + er = s3.signRequest(req) + if er != nil { + return er + } resp, er := http.DefaultClient.Do(req) if er != nil { @@ -211,7 +180,10 @@ func (s3 *S3) Get(path string) (io.ReadCloser, http.Header, error) { return nil, http.Header{}, er } - s3.signRequest(req) + er = s3.signRequest(req) + if er != nil { + return nil, http.Header{}, er + } resp, er := http.DefaultClient.Do(req) if er != nil { @@ -241,7 +213,10 @@ func (s3 *S3) Head(path string) (http.Header, error) { return http.Header{}, er } - s3.signRequest(req) + er = s3.signRequest(req) + if er != nil { + return http.Header{}, er + } resp, er := http.DefaultClient.Do(req) if er != nil { @@ -313,9 +288,10 @@ func (s3 *S3) StartMultipart(path string) (*S3Multipart, error) { return nil, er } - req.Header.Set("Host", req.URL.Host) - - s3.signRequest(req) + er = s3.signRequest(req) + if er != nil { + return nil, er + } resp, er := http.DefaultClient.Do(req) if er != nil {