Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 0 additions & 2 deletions multipart.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,6 @@ func (mp *S3Multipart) AddPart(r io.Reader, size int64, md5sum []byte) error {
}

req.Header.Set("Content-Length", fmt.Sprintf("%d", size))
req.Header.Set("Host", req.URL.Host)
req.Header.Set("Content-Type", "application/octet-stream")
req.ContentLength = size

Expand Down Expand Up @@ -114,7 +113,6 @@ func (mp *S3Multipart) Complete(contentType string) error {
}

req.Header.Set("Content-Length", fmt.Sprintf("%d", len(xmlBody)))
req.Header.Set("Host", req.URL.Host)
req.Header.Set("Content-Type", contentType)
req.ContentLength = int64(len(xmlBody))

Expand Down
104 changes: 40 additions & 64 deletions s3.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,7 @@ package s3

import (
"bytes"
"crypto/hmac"
"crypto/md5"
"crypto/sha1"
"encoding/base64"
"encoding/xml"
"fmt"
Expand All @@ -14,9 +12,12 @@ import (
"net/http"
"net/url"
"runtime"
"sort"
"strings"
"time"

"github.com/aws/aws-sdk-go/aws/credentials"
"github.com/aws/aws-sdk-go/aws/endpoints"
"github.com/aws/aws-sdk-go/aws/signer/v4"
)

// S3 provides a wrapper around your S3 credentials. It carries no other internal state
Expand All @@ -38,66 +39,32 @@ func NewS3(bucket, accessId, secret string) *S3 {
}
}

func (s3 *S3) signRequest(req *http.Request) {
amzHeaders := ""
resourceUrl, _ := url.Parse("/" + s3.bucket + req.URL.Path)
resource := resourceUrl.String()

/* Ugh, AWS requires us to order the parameters in a specific ordering for
* signing. Makes sense, but is annoying because a map does not have a defined
* ordering (and basically returns elements in a random order) -- so we have
* to sort by hand */
query := req.URL.Query()
if len(query) > 0 {
keys := []string{}

for k := range query {
keys = append(keys, k)
}

sort.Strings(keys)

parts := []string{}

for _, key := range keys {
vals := query[key]

for _, val := range vals {
if val == "" {
parts = append(parts, url.QueryEscape(key))

} else {
part := fmt.Sprintf("%s=%s", url.QueryEscape(key), url.QueryEscape(val))
parts = append(parts, part)
}
}
}

req.URL.RawQuery = strings.Join(parts, "&")
}

if req.URL.RawQuery != "" {
resource += "?" + req.URL.RawQuery
func (s4 *S3) signRequest(req *http.Request) (er error) {
signer := v4.Signer{
Credentials: credentials.NewStaticCredentials(s4.accessId, s4.secret, ""),
DisableURIPathEscaping: true,
}

t := time.Now()
if req.Header.Get("Date") == "" {
req.Header.Set("Date", time.Now().Format(time.RFC1123))
req.Header.Set("Date", t.Format(time.RFC1123))
}

authStr := strings.Join([]string{
strings.TrimSpace(req.Method),
req.Header.Get("Content-MD5"),
req.Header.Get("Content-Type"),
req.Header.Get("Date"),
amzHeaders + resource,
}, "\n")

h := hmac.New(sha1.New, []byte(s3.secret))
h.Write([]byte(authStr))
var seeker io.ReadSeeker
if req.Body != nil {
seeker = bytes.NewReader(streamToByte(req.Body))
}
_, err := signer.Sign(req, seeker, "s3", endpoints.ApNortheast1RegionID, t)
if err != nil {
return er
}
return nil
}

h64 := base64.StdEncoding.EncodeToString(h.Sum(nil))
auth := "AWS" + " " + s3.accessId + ":" + h64
req.Header.Set("Authorization", auth)
func streamToByte(stream io.Reader) []byte {
buf := new(bytes.Buffer)
buf.ReadFrom(stream)
return buf.Bytes()
}

func (s3 *S3) resource(path string, values url.Values) string {
Expand Down Expand Up @@ -177,10 +144,12 @@ func (s3 *S3) Put(r io.Reader, size int64, path string, md5sum []byte, contentTy

req.Header.Set("Content-Type", contentType)
req.Header.Set("Content-Length", fmt.Sprintf("%d", size))
req.Header.Set("Host", req.URL.Host)
req.ContentLength = size

s3.signRequest(req)
er = s3.signRequest(req)
if er != nil {
return er
}

resp, er := http.DefaultClient.Do(req)
if er != nil {
Expand Down Expand Up @@ -211,7 +180,10 @@ func (s3 *S3) Get(path string) (io.ReadCloser, http.Header, error) {
return nil, http.Header{}, er
}

s3.signRequest(req)
er = s3.signRequest(req)
if er != nil {
return nil, http.Header{}, er
}

resp, er := http.DefaultClient.Do(req)
if er != nil {
Expand Down Expand Up @@ -241,7 +213,10 @@ func (s3 *S3) Head(path string) (http.Header, error) {
return http.Header{}, er
}

s3.signRequest(req)
er = s3.signRequest(req)
if er != nil {
return http.Header{}, er
}

resp, er := http.DefaultClient.Do(req)
if er != nil {
Expand Down Expand Up @@ -313,9 +288,10 @@ func (s3 *S3) StartMultipart(path string) (*S3Multipart, error) {
return nil, er
}

req.Header.Set("Host", req.URL.Host)

s3.signRequest(req)
er = s3.signRequest(req)
if er != nil {
return nil, er
}

resp, er := http.DefaultClient.Do(req)
if er != nil {
Expand Down