diff --git a/backend/controllers/EndUser/MassUploadFileController.go b/backend/controllers/EndUser/MassUploadFileController.go index c62b01d..0559288 100644 --- a/backend/controllers/EndUser/MassUploadFileController.go +++ b/backend/controllers/EndUser/MassUploadFileController.go @@ -87,7 +87,6 @@ func NewMassUploadFileController( func (c *MassUploadFileController) MassUpload(ctx *gin.Context) { log.Printf("Starting mass file upload request") - // Get current user user, exists := ctx.Get("user") if !exists { ctx.JSON(http.StatusUnauthorized, gin.H{"status": "error", "error": "Unauthorized access"}) @@ -100,7 +99,6 @@ func (c *MassUploadFileController) MassUpload(ctx *gin.Context) { return } - // Parse form form, err := ctx.MultipartForm() if err != nil { ctx.JSON(http.StatusBadRequest, gin.H{"status": "error", "error": "Failed to parse form"}) @@ -113,7 +111,6 @@ func (c *MassUploadFileController) MassUpload(ctx *gin.Context) { return } - // Get encryption parameters encryptionType := services.EncryptionType(ctx.DefaultPostForm("encryption_type", string(services.StandardEncryption))) if err := c.validateEncryptionType(encryptionType, currentUser); err != nil { ctx.JSON(http.StatusBadRequest, gin.H{ @@ -123,7 +120,6 @@ func (c *MassUploadFileController) MassUpload(ctx *gin.Context) { return } - // Parse parameters nShares := c.parseIntParam(ctx, "shares", 5) threshold := c.parseIntParam(ctx, "threshold", 3) dataShards := c.parseIntParam(ctx, "data_shards", 4) @@ -170,27 +166,24 @@ func (c *MassUploadFileController) MassUpload(ctx *gin.Context) { wg.Add(1) go func(fh *multipart.FileHeader) { defer wg.Done() - semaphore <- struct{}{} // Acquire semaphore - defer func() { <-semaphore }() // Release semaphore + semaphore <- struct{}{} + defer func() { <-semaphore }() result := c.processUpload(ctx, fh, currentUser, folderID, uploadParams) results <- result }(fileHeader) } - // Wait for all uploads to complete go func() { wg.Wait() close(results) }() - // Collect results uploadResults := make([]UploadResult, 0, len(files)) for result := range results { uploadResults = append(uploadResults, result) } - // Return results ctx.JSON(http.StatusOK, gin.H{ "status": "success", "message": "Mass file upload processing complete", @@ -352,7 +345,6 @@ func (c *MassUploadFileController) createFileRecord( return nil, fmt.Errorf("serverKey is nil") } - // Create encoded filename encryptedFileName := base64.RawURLEncoding.EncodeToString([]byte(fileHeader.Filename)) return &models.File{ @@ -433,7 +425,6 @@ func (c *MassUploadFileController) handleFolderAssignment(ctx *gin.Context, user } parsedID := uint(id) - // Verify folder exists and belongs to user - use GetFolderByID instead of GetFolder folder, err := c.folderModel.GetFolderByID(parsedID, user.ID) if err != nil { log.Printf("Folder not found or access denied: %v", err) @@ -442,7 +433,6 @@ func (c *MassUploadFileController) handleFolderAssignment(ctx *gin.Context, user return &folder.ID } - // Try to find or create "My Files" folder folders, err := c.folderModel.GetUserFolders(user.ID) if err != nil { log.Printf("Failed to get user folders: %v", err) @@ -515,7 +505,6 @@ func (c *MassUploadFileController) getFileInfo(file *models.File, params *Upload } } -// Error types for mass upload type UploadError struct { Code string `json:"code"` Message string `json:"message"` @@ -534,10 +523,8 @@ func newUploadError(code, message string, details string) *UploadError { } } -// Additional helper methods for handling large uploads func (c *MassUploadFileController) calculateBatchSize(fileCount int) int { - // Calculate optimal batch size based on available system resources - // Default to 5 concurrent uploads, but adjust based on file count + if fileCount <= 5 { return fileCount } @@ -551,7 +538,6 @@ func (c *MassUploadFileController) validateTotalSize(files []*multipart.FileHead var totalSize int64 for _, file := range files { totalSize += file.Size - // Individual file size limit (e.g., 2GB) if file.Size > 2<<30 { return 0, newUploadError( "FILE_TOO_LARGE", diff --git a/backend/controllers/EndUser/PasswordResetController.go b/backend/controllers/EndUser/PasswordResetController.go index 5e87db9..1fb8347 100644 --- a/backend/controllers/EndUser/PasswordResetController.go +++ b/backend/controllers/EndUser/PasswordResetController.go @@ -12,7 +12,6 @@ import ( type PasswordResetController struct { userModel *models.UserModel passwordHistoryModel *models.PasswordHistoryModel - keyRotationModel *models.KeyRotationModel keyFragmentModel *models.KeyFragmentModel fileModel *models.FileModel } @@ -25,14 +24,12 @@ type PasswordResetRequest struct { func NewPasswordResetController( userModel *models.UserModel, passwordHistoryModel *models.PasswordHistoryModel, - keyRotationModel *models.KeyRotationModel, keyFragmentModel *models.KeyFragmentModel, fileModel *models.FileModel, ) *PasswordResetController { return &PasswordResetController{ userModel: userModel, passwordHistoryModel: passwordHistoryModel, - keyRotationModel: keyRotationModel, keyFragmentModel: keyFragmentModel, fileModel: fileModel, } @@ -63,7 +60,6 @@ func (c *PasswordResetController) ResetPassword(ctx *gin.Context) { return } - // Call the model method to handle all database operations err := c.userModel.ResetPasswordWithFragments( endUser.ID, req.CurrentPassword, @@ -76,7 +72,6 @@ func (c *PasswordResetController) ResetPassword(ctx *gin.Context) { if err != nil { log.Printf("Password reset failed for user %d: %v", endUser.ID, err) - // Map common errors to appropriate HTTP status codes status := http.StatusInternalServerError switch err.Error() { case "current password is incorrect": diff --git a/backend/controllers/EndUser/ShareFileController.go b/backend/controllers/EndUser/ShareFileController.go index aba4759..cb184d6 100644 --- a/backend/controllers/EndUser/ShareFileController.go +++ b/backend/controllers/EndUser/ShareFileController.go @@ -6,22 +6,28 @@ import ( "log" "net/http" "net/url" + "os" "safesplit/models" "safesplit/services" + "strconv" "strings" + "time" "github.com/gin-gonic/gin" ) type ShareFileController struct { - fileModel *models.FileModel - fileShareModel *models.FileShareModel - keyFragmentModel *models.KeyFragmentModel - encryptionService *services.EncryptionService - activityLogModel *models.ActivityLogModel - rsService *services.ReedSolomonService - userModel *models.UserModel - serverKeyModel *models.ServerMasterKeyModel + fileModel *models.FileModel + fileShareModel *models.FileShareModel + keyFragmentModel *models.KeyFragmentModel + encryptionService *services.EncryptionService + activityLogModel *models.ActivityLogModel + rsService *services.ReedSolomonService + userModel *models.UserModel + serverKeyModel *models.ServerMasterKeyModel + twoFactorService *services.TwoFactorAuthService + emailService *services.SMTPEmailService + compressionService *services.CompressionService } func NewShareFileController( @@ -33,340 +39,443 @@ func NewShareFileController( rsService *services.ReedSolomonService, userModel *models.UserModel, serverKeyModel *models.ServerMasterKeyModel, + twoFactorService *services.TwoFactorAuthService, + emailService *services.SMTPEmailService, + compressionService *services.CompressionService, ) *ShareFileController { return &ShareFileController{ - fileModel: fileModel, - fileShareModel: fileShareModel, - keyFragmentModel: keyFragmentModel, - encryptionService: encryptionService, - activityLogModel: activityLogModel, - rsService: rsService, - userModel: userModel, - serverKeyModel: serverKeyModel, + fileModel: fileModel, + fileShareModel: fileShareModel, + keyFragmentModel: keyFragmentModel, + encryptionService: encryptionService, + activityLogModel: activityLogModel, + rsService: rsService, + userModel: userModel, + serverKeyModel: serverKeyModel, + twoFactorService: twoFactorService, + emailService: emailService, + compressionService: compressionService, } } type CreateShareRequest struct { - FileID uint `json:"file_id" binding:"required"` - Password string `json:"password" binding:"required,min=6"` + ShareType models.ShareType `json:"share_type" binding:"required"` + Password string `json:"password" binding:"required,min=6"` + Email string `json:"email,omitempty"` } type AccessShareRequest struct { Password string `json:"password" binding:"required"` + Email string `json:"email,omitempty"` +} + +type TwoFactorRequest struct { + Code string `json:"code" binding:"required"` + Password string `json:"password" binding:"required"` } func (c *ShareFileController) CreateShare(ctx *gin.Context) { - log.Printf("Received normal share creation request for file ID: %v", ctx.Param("id")) + // Get file ID from URL parameter + fileID := ctx.Param("id") + if fileID == "" { + ctx.JSON(http.StatusBadRequest, gin.H{"error": "File ID is required"}) + return + } + + // Convert string ID to uint + id, err := strconv.ParseUint(fileID, 10, 64) + if err != nil { + ctx.JSON(http.StatusBadRequest, gin.H{"error": "Invalid file ID"}) + return + } + var req CreateShareRequest if err := ctx.ShouldBindJSON(&req); err != nil { - ctx.JSON(http.StatusBadRequest, gin.H{ - "status": "error", - "error": "Invalid request data", - }) + ctx.JSON(http.StatusBadRequest, gin.H{"error": "Invalid request data"}) return } - // Get current user - user, exists := ctx.Get("user") - if !exists { - ctx.JSON(http.StatusUnauthorized, gin.H{ - "status": "error", - "error": "Unauthorized access", - }) + if req.ShareType == models.RecipientShare && req.Email == "" { + ctx.JSON(http.StatusBadRequest, gin.H{"error": "Email required for recipient share"}) return } - currentUser := user.(*models.User) - // Get file and verify ownership - file, err := c.fileModel.GetFileForDownload(req.FileID, currentUser.ID) + user := ctx.MustGet("user").(*models.User) + file, err := c.fileModel.GetFileForDownload(uint(id), user.ID) if err != nil { - ctx.JSON(http.StatusNotFound, gin.H{ - "status": "error", - "error": "File not found or access denied", - }) + ctx.JSON(http.StatusNotFound, gin.H{"error": "File not found"}) return } - // Derive user's KEK - kek, err := services.DeriveKeyEncryptionKey(currentUser.Password, currentUser.MasterKeySalt) + kek, err := services.DeriveKeyEncryptionKey(user.Password, user.MasterKeySalt) if err != nil { - log.Printf("Failed to derive KEK: %v", err) - ctx.JSON(http.StatusInternalServerError, gin.H{ - "status": "error", - "error": "Failed to process encryption", - }) + ctx.JSON(http.StatusInternalServerError, gin.H{"error": "Encryption failed"}) return } - // Decrypt user's master key - decryptedMasterKey, err := services.DecryptMasterKey( - currentUser.EncryptedMasterKey, - kek, - currentUser.MasterKeyNonce, - ) + decryptedMasterKey, err := services.DecryptMasterKey(user.EncryptedMasterKey, kek, user.MasterKeyNonce) if err != nil { - log.Printf("Failed to decrypt master key: %v", err) - ctx.JSON(http.StatusInternalServerError, gin.H{ - "status": "error", - "error": "Failed to process encryption", - }) + ctx.JSON(http.StatusInternalServerError, gin.H{"error": "Decryption failed"}) return } - // Use first 32 bytes of decrypted master key userMasterKey := decryptedMasterKey[:32] - - // Get fragments fragments, err := c.keyFragmentModel.GetUserFragmentsForFile(file.ID) if err != nil || len(fragments) == 0 { - log.Printf("Failed to retrieve key fragments for file %d: %v", file.ID, err) - ctx.JSON(http.StatusInternalServerError, gin.H{ - "status": "error", - "error": "Failed to retrieve key fragments", - }) + ctx.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to get fragments"}) return } - log.Printf("Creating share for file %d with %d fragments", file.ID, len(fragments)) - // Get first fragment and remember its index userFragment := fragments[0] - log.Printf("Selected user fragment with index %d for sharing", userFragment.FragmentIndex) - - // Decrypt the fragment we'll share using master key decryptedFragment, err := services.DecryptMasterKey( userFragment.Data, userMasterKey, userFragment.EncryptionNonce, ) if err != nil { - log.Printf("Failed to decrypt fragment: %v", err) - ctx.JSON(http.StatusInternalServerError, gin.H{ - "status": "error", - "error": "Failed to process share creation", - }) + ctx.JSON(http.StatusInternalServerError, gin.H{"error": "Fragment decryption failed"}) return } - // Encrypt decrypted fragment with share password encryptedFragment, err := c.encryptionService.EncryptKeyFragment( decryptedFragment, []byte(req.Password), ) if err != nil { - log.Printf("Failed to encrypt key fragment: %v", err) - ctx.JSON(http.StatusInternalServerError, gin.H{ - "status": "error", - "error": "Failed to process share encryption", - }) + ctx.JSON(http.StatusInternalServerError, gin.H{"error": "Fragment encryption failed"}) return } - // Create share record with original fragment index share := &models.FileShare{ FileID: file.ID, - SharedBy: currentUser.ID, + SharedBy: user.ID, EncryptedKeyFragment: encryptedFragment, - FragmentIndex: userFragment.FragmentIndex, // Store original index + FragmentIndex: userFragment.FragmentIndex, IsActive: true, + ShareType: req.ShareType, + Email: req.Email, } if err := c.fileShareModel.CreateFileShare(share, req.Password); err != nil { - log.Printf("Failed to create file share: %v", err) - ctx.JSON(http.StatusInternalServerError, gin.H{ - "status": "error", - "error": "Failed to create share", - }) + ctx.JSON(http.StatusInternalServerError, gin.H{"error": "Share creation failed"}) return } - if err := c.activityLogModel.LogActivity(&models.ActivityLog{ - UserID: currentUser.ID, + // Get base URL from environment variable + baseURL := os.Getenv("BASE_URL") + if baseURL == "" { + baseURL = "http://localhost:8080" + } + + // Create the complete share URL + shareURL := fmt.Sprintf("%s/api/files/share/%s", baseURL, share.ShareLink) + + if req.ShareType == models.RecipientShare { + // Get base URL from environment variable + baseURL := os.Getenv("BASE_URL") + if baseURL == "" { + baseURL = "http://localhost:3000" + } + + // Create the frontend share URL (not the API URL) + shareURL := fmt.Sprintf("%s/protected-share/%s", baseURL, share.ShareLink) + + emailBody := fmt.Sprintf(`Hello, + + You have received a secure file share from %s. + + File: %s + Access Link: %s + + This link requires a password and email verification to access. + Please use the same email address this message was sent to when accessing the file. + + Best regards, + SafeSplit Team`, user.Username, file.OriginalName, shareURL) + + if err := c.emailService.SendEmail( + req.Email, + "Secure File Share Received", + emailBody, + ); err != nil { + log.Printf("Failed to send email: %v", err) + } + } + + c.activityLogModel.LogActivity(&models.ActivityLog{ + UserID: user.ID, ActivityType: "share", FileID: &file.ID, IPAddress: ctx.ClientIP(), Status: "success", - }); err != nil { - log.Printf("Failed to log share activity: %v", err) - } + Details: fmt.Sprintf("Created %s share", req.ShareType), + }) ctx.JSON(http.StatusOK, gin.H{ "status": "success", "data": gin.H{ - "share_link": share.ShareLink, + "share_link": shareURL, + "raw_link": share.ShareLink, + "requires_2fa": req.ShareType == models.RecipientShare, }, }) } func (c *ShareFileController) AccessShare(ctx *gin.Context) { shareLink := ctx.Param("shareLink") - log.Printf("Received normal share access request for link: %s", shareLink) + + // For GET requests, return file info and requirements + if ctx.Request.Method == "GET" { + share, err := c.fileShareModel.GetShareByLink(shareLink) + if err != nil { + ctx.JSON(http.StatusUnauthorized, gin.H{ + "status": "error", + "error": "Invalid share"}) + return + } + + // Get file info + file, err := c.fileModel.GetFileByID(share.FileID) + if err != nil { + ctx.JSON(http.StatusNotFound, gin.H{ + "status": "error", + "error": "File not found"}) + return + } + + ctx.JSON(http.StatusOK, gin.H{ + "status": "success", + "data": gin.H{ + "requires_password": true, + "requires_2fa": share.ShareType == models.RecipientShare, + "recipient_share": share.ShareType == models.RecipientShare, + "file_name": file.OriginalName, + "file_size": file.Size, + "mime_type": file.MimeType, + "created_at": share.CreatedAt, + "expires_at": share.ExpiresAt, + "download_count": share.DownloadCount, + "max_downloads": share.MaxDownloads, + }, + }) + return + } + + // Handle POST request var req AccessShareRequest if err := ctx.ShouldBindJSON(&req); err != nil { ctx.JSON(http.StatusBadRequest, gin.H{ "status": "error", - "error": "Invalid password", - }) + "error": "Invalid request"}) return } - // Get and validate share - share, err := c.fileShareModel.ValidateShare(shareLink, req.Password) + share, err := c.fileShareModel.GetShareByLink(shareLink) if err != nil { ctx.JSON(http.StatusUnauthorized, gin.H{ "status": "error", - "error": "Invalid share link or password", - }) + "error": "Invalid share"}) return } - log.Printf("Processing share access for link: %s", shareLink) - - // Get file metadata + // Get file info early to use in verification file, err := c.fileModel.GetFileByID(share.FileID) if err != nil { - log.Printf("Failed to get file %d: %v", share.FileID, err) ctx.JSON(http.StatusNotFound, gin.H{ "status": "error", - "error": "File not found", + "error": "File not found"}) + return + } + + // Check if share has expired + if share.ExpiresAt != nil && time.Now().After(*share.ExpiresAt) { + ctx.JSON(http.StatusForbidden, gin.H{ + "status": "error", + "error": "Share link has expired"}) + return + } + + // Check if maximum downloads reached + if share.MaxDownloads != nil && share.DownloadCount >= *share.MaxDownloads { + ctx.JSON(http.StatusForbidden, gin.H{ + "status": "error", + "error": "Maximum number of downloads reached"}) + return + } + + // Check if share is still active + if !share.IsActive { + ctx.JSON(http.StatusForbidden, gin.H{ + "status": "error", + "error": "Share link is no longer active"}) + return + } + + // Validate share based on type + if share.ShareType == models.RecipientShare { + // Validate password only + share, validationErr := c.fileShareModel.ValidateRecipientShare(shareLink, req.Password) + if validationErr != nil { + ctx.JSON(http.StatusUnauthorized, gin.H{ + "status": "error", + "error": "Invalid password"}) + return + } + + // Send 2FA to the email associated with the share + if err := c.twoFactorService.SendShareVerificationToken(share.ID, share.Email, file.OriginalName); err != nil { + ctx.JSON(http.StatusInternalServerError, gin.H{ + "status": "error", + "error": "Failed to send 2FA code"}) + return + } + + ctx.JSON(http.StatusOK, gin.H{ + "status": "success", + "message": "2FA code sent to registered email", + "data": gin.H{ + "share_id": share.ID, + }, }) return + } else { + // For normal shares, just validate password + share, err := c.fileShareModel.ValidateShare(shareLink, req.Password) + if err != nil { + ctx.JSON(http.StatusUnauthorized, gin.H{ + "status": "error", + "error": "Invalid password"}) + return + } + c.processFileAccess(ctx, share, req.Password) + } +} +func (c *ShareFileController) Verify2FAAndDownload(ctx *gin.Context) { + shareLink := ctx.Param("shareLink") + var req TwoFactorRequest + if err := ctx.ShouldBindJSON(&req); err != nil { + ctx.JSON(http.StatusBadRequest, gin.H{ + "status": "error", + "error": "Invalid request data"}) + return } - // Get server fragments - serverFragments, err := c.keyFragmentModel.GetServerFragmentsForFile(share.FileID) + share, err := c.fileShareModel.GetShareByLink(shareLink) if err != nil { - log.Printf("Failed to get server fragments: %v", err) - ctx.JSON(http.StatusInternalServerError, gin.H{ + ctx.JSON(http.StatusUnauthorized, gin.H{ "status": "error", - "error": "Failed to process file access", - }) + "error": "Invalid share"}) return } - log.Printf("Retrieved %d server fragments", len(serverFragments)) + // Verify share type + if share.ShareType != models.RecipientShare { + ctx.JSON(http.StatusBadRequest, gin.H{ + "status": "error", + "error": "2FA verification only required for recipient shares"}) + return + } - // Verify we have enough fragments - if len(serverFragments)+1 < int(file.Threshold) { // +1 for shared fragment - log.Printf("Insufficient fragments: have %d server + 1 shared, need %d", - len(serverFragments), file.Threshold) - ctx.JSON(http.StatusInternalServerError, gin.H{ + // Verify 2FA code + if err := c.twoFactorService.VerifyToken(share.ID, req.Code); err != nil { + ctx.JSON(http.StatusUnauthorized, gin.H{ "status": "error", - "error": "Insufficient fragments to reconstruct file", - }) + "error": "Invalid 2FA code"}) + return + } + + // Process file download + c.processFileAccess(ctx, share, req.Password) +} + +func (c *ShareFileController) processFileAccess(ctx *gin.Context, share *models.FileShare, password string) { + file, err := c.fileModel.GetFileByID(share.FileID) + if err != nil { + ctx.JSON(http.StatusNotFound, gin.H{"error": "File not found"}) + return + } + + serverFragments, err := c.keyFragmentModel.GetServerFragmentsForFile(share.FileID) + if err != nil || len(serverFragments)+1 < int(file.Threshold) { + ctx.JSON(http.StatusInternalServerError, gin.H{"error": "Insufficient fragments"}) return } - // Get server key for decrypting server fragments serverKey, err := c.serverKeyModel.GetActive() if err != nil { - log.Printf("Failed to get server key: %v", err) - ctx.JSON(http.StatusInternalServerError, gin.H{ - "status": "error", - "error": "Failed to process decryption", - }) + ctx.JSON(http.StatusInternalServerError, gin.H{"error": "Server key error"}) return } serverKeyData, err := c.serverKeyModel.GetServerKey(serverKey.KeyID) if err != nil { - log.Printf("Failed to get server key data: %v", err) - ctx.JSON(http.StatusInternalServerError, gin.H{ - "status": "error", - "error": "Failed to get server key", - }) + ctx.JSON(http.StatusInternalServerError, gin.H{"error": "Server key data error"}) return } - // Decrypt shared fragment sharedDecryptedFragment, err := c.encryptionService.DecryptKeyFragment( share.EncryptedKeyFragment, - []byte(req.Password), + []byte(password), ) if err != nil { - log.Printf("Failed to decrypt shared fragment: %v", err) - ctx.JSON(http.StatusInternalServerError, gin.H{ - "status": "error", - "error": "Failed to process file decryption", - }) + ctx.JSON(http.StatusInternalServerError, gin.H{"error": "Fragment decryption failed"}) return } - // We need threshold number of unique shares shares := make([]services.KeyShare, file.Threshold) usedIndices := make(map[int]bool) - // Add the shared fragment first with its original index shares[0] = services.KeyShare{ - Index: share.FragmentIndex, // Use stored original index + Index: share.FragmentIndex, Value: hex.EncodeToString(sharedDecryptedFragment), } usedIndices[share.FragmentIndex] = true - log.Printf("Added shared fragment with original index %d", share.FragmentIndex) - // Add server fragments with unique indices - sharesAdded := uint(1) // Start at 1 since we added shared fragment + sharesAdded := uint(1) for i := 0; i < len(serverFragments) && sharesAdded < file.Threshold; i++ { fragment := serverFragments[i] - - // Skip if we've used this index if usedIndices[fragment.FragmentIndex] { continue } - // Decrypt server fragment decryptedFragment, err := services.DecryptMasterKey( fragment.Data, serverKeyData, fragment.EncryptionNonce, ) if err != nil { - log.Printf("Failed to decrypt server fragment %d: %v", i, err) continue } shares[sharesAdded] = services.KeyShare{ - Index: fragment.FragmentIndex, // Use original server fragment index + Index: fragment.FragmentIndex, Value: hex.EncodeToString(decryptedFragment), NodeIndex: fragment.NodeIndex, FragmentPath: fragment.FragmentPath, } usedIndices[fragment.FragmentIndex] = true - log.Printf("Added server fragment %d with original index %d", i, fragment.FragmentIndex) sharesAdded++ } - // Verify we have enough unique shares if sharesAdded < file.Threshold { - log.Printf("Failed to get enough unique shares: have %d, need %d", sharesAdded, file.Threshold) - ctx.JSON(http.StatusInternalServerError, gin.H{ - "status": "error", - "error": "Failed to get enough unique shares", - }) + ctx.JSON(http.StatusInternalServerError, gin.H{"error": "Insufficient unique shares"}) return } - // Get encrypted file data var encryptedData []byte var retrievalErr error if file.IsSharded { - log.Printf("Retrieving sharded data for file %d", file.ID) encryptedData, retrievalErr = c.getShardedData(file) } else { - log.Printf("Reading file content from path: %s", file.FilePath) encryptedData, retrievalErr = c.fileModel.ReadFileContent(file.FilePath) } if retrievalErr != nil { - log.Printf("Failed to retrieve file data: %v", retrievalErr) - ctx.JSON(http.StatusInternalServerError, gin.H{ - "status": "error", - "error": "Failed to read file data", - }) + ctx.JSON(http.StatusInternalServerError, gin.H{"error": "File retrieval failed"}) return } - // Decrypt the file decryptedData, err := c.encryptionService.DecryptFileWithType( encryptedData, file.EncryptionIV, @@ -376,72 +485,63 @@ func (c *ShareFileController) AccessShare(ctx *gin.Context) { file.EncryptionType, ) if err != nil { - log.Printf("Failed to decrypt file data: %v", err) - ctx.JSON(http.StatusInternalServerError, gin.H{ - "status": "error", - "error": "Failed to decrypt file", - }) + ctx.JSON(http.StatusInternalServerError, gin.H{"error": "File decryption failed"}) return } - // Log share access - if err := c.activityLogModel.LogActivity(&models.ActivityLog{ + // Handle decompression if the file is compressed + if file.IsCompressed { + log.Printf("Decompressing data for file ID: %d", file.ID) + decryptedData, err = c.compressionService.Decompress(decryptedData) + if err != nil { + ctx.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to decompress file"}) + return + } + } + + if err := c.fileShareModel.IncrementDownloadCount(share.ID); err != nil { + log.Printf("Failed to increment download count: %v", err) + } + + c.activityLogModel.LogActivity(&models.ActivityLog{ UserID: share.SharedBy, ActivityType: "download", FileID: &file.ID, IPAddress: ctx.ClientIP(), Status: "success", - Details: fmt.Sprintf("Shared file download using %d fragments", file.Threshold), - }); err != nil { - log.Printf("Failed to log share download activity: %v", err) - } + Details: fmt.Sprintf("Download with %d fragments", file.Threshold), + }) - // Send file response c.sendFileResponse(ctx, file, decryptedData) } - func (c *ShareFileController) getShardedData(file *models.File) ([]byte, error) { fileShards, err := c.rsService.RetrieveShards(file.ID, int(file.DataShardCount+file.ParityShardCount)) if err != nil { return nil, fmt.Errorf("failed to retrieve shards: %w", err) } - validShards := 0 - for i, shard := range fileShards.Shards { - if shard != nil { - validShards++ - log.Printf("Shard %d: %d bytes", i, len(shard)) - } else { - log.Printf("Shard %d: Missing", i) - } - } - if !c.rsService.ValidateShards(fileShards.Shards, int(file.DataShardCount)) { - return nil, fmt.Errorf("insufficient shards for reconstruction: have %d, need %d", - validShards, file.DataShardCount) + return nil, fmt.Errorf("insufficient shards for reconstruction") } - reconstructed, err := c.rsService.ReconstructFile(fileShards.Shards, + return c.rsService.ReconstructFile(fileShards.Shards, int(file.DataShardCount), int(file.ParityShardCount)) - if err != nil { - return nil, fmt.Errorf("failed to reconstruct file: %w", err) - } - - log.Printf("Successfully reconstructed file data: %d bytes", len(reconstructed)) - return reconstructed, nil } func (c *ShareFileController) sendFileResponse(ctx *gin.Context, file *models.File, data []byte) { - sanitizedFilename := strings.ReplaceAll(file.OriginalName, `"`, `\"`) - encodedFilename := url.QueryEscape(sanitizedFilename) - - ctx.Header("Access-Control-Expose-Headers", "Content-Disposition, Content-Type, Content-Length") - ctx.Header("Content-Disposition", fmt.Sprintf(`attachment; filename="%s"; filename*=UTF-8''%s`, - sanitizedFilename, - encodedFilename)) + escapedName := strings.ReplaceAll(file.OriginalName, `"`, `\"`) + utf8Name := url.PathEscape(file.OriginalName) + ctx.Header("Content-Disposition", fmt.Sprintf( + `attachment; filename="%s"; filename*=UTF-8''%s`, + escapedName, + utf8Name, + )) ctx.Header("Content-Type", file.MimeType) ctx.Header("Content-Length", fmt.Sprintf("%d", len(data))) - + ctx.Header("X-Original-Filename", escapedName) + ctx.Header("Access-Control-Expose-Headers", "Content-Disposition, Content-Type, Content-Length, X-Original-Filename") + ctx.Header("Content-Description", "File Transfer") + ctx.Header("Content-Transfer-Encoding", "binary") log.Printf("Sending file response: %s (Size: %d bytes)", file.OriginalName, len(data)) ctx.Data(http.StatusOK, file.MimeType, data) } diff --git a/backend/controllers/EndUser/TwoFactorAuthController.go b/backend/controllers/EndUser/TwoFactorAuthController.go index 3635fb7..28966be 100644 --- a/backend/controllers/EndUser/TwoFactorAuthController.go +++ b/backend/controllers/EndUser/TwoFactorAuthController.go @@ -1,90 +1,223 @@ package EndUser import ( - "fmt" - "log" - "net/http" - "safesplit/models" + "fmt" + "log" + "net/http" + "safesplit/models" + "safesplit/services" - "github.com/gin-gonic/gin" + "github.com/gin-gonic/gin" ) type TwoFactorController struct { - userModel *models.UserModel + userModel *models.UserModel + twoFactorService *services.TwoFactorAuthService } -func NewTwoFactorController(userModel *models.UserModel) *TwoFactorController { - return &TwoFactorController{ - userModel: userModel, - } +func NewTwoFactorController(userModel *models.UserModel, twoFactorService *services.TwoFactorAuthService) *TwoFactorController { + return &TwoFactorController{ + userModel: userModel, + twoFactorService: twoFactorService, + } } -func (c *TwoFactorController) EnableEmailTwoFactor(ctx *gin.Context) { - userID, exists := ctx.Get("user_id") - if !exists { - ctx.JSON(http.StatusBadRequest, gin.H{"error": "User ID not found in context"}) - return - } - - uid, ok := userID.(uint) - if !ok { - ctx.JSON(http.StatusBadRequest, gin.H{"error": "Invalid user ID type"}) - return - } - - log.Printf("Enabling 2FA for user ID: %d", uid) - - if err := c.userModel.EnableEmailTwoFactor(uid); err != nil { - ctx.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("Failed to enable 2FA: %v", err)}) - return - } - - ctx.JSON(http.StatusOK, gin.H{"message": "2FA enabled successfully"}) +func (c *TwoFactorController) GetTwoFactorStatus(ctx *gin.Context) { + userID, exists := ctx.Get("user_id") + if !exists { + ctx.JSON(http.StatusBadRequest, gin.H{"error": "User ID not found in context"}) + return + } + + uid, ok := userID.(uint) + if !ok { + ctx.JSON(http.StatusBadRequest, gin.H{"error": "Invalid user ID type"}) + return + } + + user, err := c.userModel.FindByID(uid) + if err != nil { + ctx.JSON(http.StatusNotFound, gin.H{"error": "User not found"}) + return + } + + ctx.JSON(http.StatusOK, gin.H{ + "two_factor_enabled": user.TwoFactorEnabled, + }) } -func (c *TwoFactorController) DisableEmailTwoFactor(ctx *gin.Context) { - userID, exists := ctx.Get("user_id") - if !exists { - ctx.JSON(http.StatusBadRequest, gin.H{"error": "User ID not found in context"}) - return - } - - uid, ok := userID.(uint) - if !ok { - ctx.JSON(http.StatusBadRequest, gin.H{"error": "Invalid user ID type"}) - return - } - - log.Printf("Disabling 2FA for user ID: %d", uid) - - if err := c.userModel.DisableEmailTwoFactor(uid); err != nil { - ctx.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("Failed to disable 2FA: %v", err)}) - return - } +func (c *TwoFactorController) InitiateEnable2FA(ctx *gin.Context) { + userID, exists := ctx.Get("user_id") + if !exists { + ctx.JSON(http.StatusBadRequest, gin.H{"error": "User ID not found in context"}) + return + } + + uid, ok := userID.(uint) + if !ok { + ctx.JSON(http.StatusBadRequest, gin.H{"error": "Invalid user ID type"}) + return + } + + // Get user's email + user, err := c.userModel.FindByID(uid) + if err != nil { + ctx.JSON(http.StatusNotFound, gin.H{"error": "User not found"}) + return + } + + // Verify 2FA is not already enabled + if user.TwoFactorEnabled { + ctx.JSON(http.StatusBadRequest, gin.H{"error": "2FA is already enabled"}) + return + } + + log.Printf("Initiating 2FA enable for user ID: %d", uid) + + // Send verification code + if err := c.twoFactorService.SendTwoFactorToken(uid, user.Email); err != nil { + ctx.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("Failed to send verification code: %v", err)}) + return + } + + ctx.JSON(http.StatusOK, gin.H{"message": "Verification code sent to your email"}) +} - ctx.JSON(http.StatusOK, gin.H{"message": "2FA disabled successfully"}) +func (c *TwoFactorController) VerifyAndEnable2FA(ctx *gin.Context) { + userID, exists := ctx.Get("user_id") + if !exists { + ctx.JSON(http.StatusBadRequest, gin.H{"error": "User ID not found in context"}) + return + } + + uid, ok := userID.(uint) + if !ok { + ctx.JSON(http.StatusBadRequest, gin.H{"error": "Invalid user ID type"}) + return + } + + var req struct { + Code string `json:"code" binding:"required"` + } + + if err := ctx.ShouldBindJSON(&req); err != nil { + ctx.JSON(http.StatusBadRequest, gin.H{"error": "Invalid request body"}) + return + } + + // Verify 2FA is not already enabled + user, err := c.userModel.FindByID(uid) + if err != nil { + ctx.JSON(http.StatusNotFound, gin.H{"error": "User not found"}) + return + } + + if user.TwoFactorEnabled { + ctx.JSON(http.StatusBadRequest, gin.H{"error": "2FA is already enabled"}) + return + } + + // Verify the code + if err := c.twoFactorService.VerifyToken(uid, req.Code); err != nil { + ctx.JSON(http.StatusUnauthorized, gin.H{"error": "Invalid or expired verification code"}) + return + } + + log.Printf("Enabling 2FA for user ID: %d after verification", uid) + + // Enable 2FA + if err := c.userModel.EnableEmailTwoFactor(uid); err != nil { + ctx.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("Failed to enable 2FA: %v", err)}) + return + } + + ctx.JSON(http.StatusOK, gin.H{"message": "2FA enabled successfully"}) } -func (c *TwoFactorController) GetTwoFactorStatus(ctx *gin.Context) { - userID, exists := ctx.Get("user_id") - if !exists { - ctx.JSON(http.StatusBadRequest, gin.H{"error": "User ID not found in context"}) - return - } - - uid, ok := userID.(uint) - if !ok { - ctx.JSON(http.StatusBadRequest, gin.H{"error": "Invalid user ID type"}) - return - } - - user, err := c.userModel.FindByID(uid) - if err != nil { - ctx.JSON(http.StatusNotFound, gin.H{"error": "User not found"}) - return - } - - ctx.JSON(http.StatusOK, gin.H{ - "two_factor_enabled": user.TwoFactorEnabled, - }) +func (c *TwoFactorController) InitiateDisable2FA(ctx *gin.Context) { + userID, exists := ctx.Get("user_id") + if !exists { + ctx.JSON(http.StatusBadRequest, gin.H{"error": "User ID not found in context"}) + return + } + + uid, ok := userID.(uint) + if !ok { + ctx.JSON(http.StatusBadRequest, gin.H{"error": "Invalid user ID type"}) + return + } + + // Get user's email + user, err := c.userModel.FindByID(uid) + if err != nil { + ctx.JSON(http.StatusNotFound, gin.H{"error": "User not found"}) + return + } + + // Verify 2FA is enabled + if !user.TwoFactorEnabled { + ctx.JSON(http.StatusBadRequest, gin.H{"error": "2FA is not enabled"}) + return + } + + log.Printf("Initiating 2FA disable for user ID: %d", uid) + + // Send verification code + if err := c.twoFactorService.SendTwoFactorToken(uid, user.Email); err != nil { + ctx.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("Failed to send verification code: %v", err)}) + return + } + + ctx.JSON(http.StatusOK, gin.H{"message": "Verification code sent to your email"}) } + +func (c *TwoFactorController) VerifyAndDisable2FA(ctx *gin.Context) { + userID, exists := ctx.Get("user_id") + if !exists { + ctx.JSON(http.StatusBadRequest, gin.H{"error": "User ID not found in context"}) + return + } + + uid, ok := userID.(uint) + if !ok { + ctx.JSON(http.StatusBadRequest, gin.H{"error": "Invalid user ID type"}) + return + } + + var req struct { + Code string `json:"code" binding:"required"` + } + + if err := ctx.ShouldBindJSON(&req); err != nil { + ctx.JSON(http.StatusBadRequest, gin.H{"error": "Invalid request body"}) + return + } + + // Verify 2FA is still enabled + user, err := c.userModel.FindByID(uid) + if err != nil { + ctx.JSON(http.StatusNotFound, gin.H{"error": "User not found"}) + return + } + + if !user.TwoFactorEnabled { + ctx.JSON(http.StatusBadRequest, gin.H{"error": "2FA is not enabled"}) + return + } + + // Verify the code + if err := c.twoFactorService.VerifyToken(uid, req.Code); err != nil { + ctx.JSON(http.StatusUnauthorized, gin.H{"error": "Invalid or expired verification code"}) + return + } + + log.Printf("Disabling 2FA for user ID: %d after verification", uid) + + // Disable 2FA + if err := c.userModel.DisableEmailTwoFactor(uid); err != nil { + ctx.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("Failed to disable 2FA: %v", err)}) + return + } + + ctx.JSON(http.StatusOK, gin.H{"message": "2FA disabled successfully"}) +} \ No newline at end of file diff --git a/backend/controllers/EndUser/UnarchiveFileController.go b/backend/controllers/EndUser/UnarchiveFileController.go new file mode 100644 index 0000000..c422daf --- /dev/null +++ b/backend/controllers/EndUser/UnarchiveFileController.go @@ -0,0 +1,61 @@ +package EndUser + +import ( + "net/http" + "safesplit/models" + "strconv" + + "github.com/gin-gonic/gin" +) + +type UnarchiveFileController struct { + fileModel *models.FileModel +} + +func NewUnarchiveFileController(fileModel *models.FileModel) *UnarchiveFileController { + return &UnarchiveFileController{ + fileModel: fileModel, + } +} + +func (c *UnarchiveFileController) Unarchive(ctx *gin.Context) { + // Get user ID from context + userID := ctx.GetUint("user_id") + if userID == 0 { + ctx.JSON(http.StatusUnauthorized, gin.H{ + "status": "error", + "error": "Unauthorized access", + }) + return + } + + // Parse file ID from URL parameter + fileID, err := strconv.ParseUint(ctx.Param("id"), 10, 32) + if err != nil { + ctx.JSON(http.StatusBadRequest, gin.H{ + "status": "error", + "error": "Invalid file ID", + }) + return + } + + // Call the model method to unarchive the file + err = c.fileModel.UnarchiveFile(uint(fileID), userID, ctx.ClientIP()) + if err != nil { + status := http.StatusInternalServerError + if err.Error() == "file not found or not archived" { + status = http.StatusNotFound + } + ctx.JSON(status, gin.H{ + "status": "error", + "error": err.Error(), + }) + return + } + + // Return success response + ctx.JSON(http.StatusOK, gin.H{ + "status": "success", + "message": "File unarchived successfully", + }) +} \ No newline at end of file diff --git a/backend/controllers/EndUser/massUnarchiveFileController.go b/backend/controllers/EndUser/massUnarchiveFileController.go new file mode 100644 index 0000000..a3be774 --- /dev/null +++ b/backend/controllers/EndUser/massUnarchiveFileController.go @@ -0,0 +1,56 @@ +package EndUser + +import ( + "net/http" + "safesplit/models" + + "github.com/gin-gonic/gin" +) + +type MassUnarchiveFileController struct { + fileModel *models.FileModel +} + +func NewMassUnarchiveFileController(fileModel *models.FileModel) *MassUnarchiveFileController { + return &MassUnarchiveFileController{ + fileModel: fileModel, + } +} + +func (c *MassUnarchiveFileController) Unarchive(ctx *gin.Context) { + userID := ctx.GetUint("user_id") + if userID == 0 { + ctx.JSON(http.StatusUnauthorized, gin.H{ + "status": "error", + "error": "Unauthorized access", + }) + return + } + + var request struct { + FileIDs []uint `json:"file_ids" binding:"required"` + } + + if err := ctx.ShouldBindJSON(&request); err != nil { + ctx.JSON(http.StatusBadRequest, gin.H{ + "status": "error", + "error": "Invalid request body", + }) + return + } + + results := make(map[uint]string) + for _, fileID := range request.FileIDs { + err := c.fileModel.UnarchiveFile(fileID, userID, ctx.ClientIP()) + if err != nil { + results[fileID] = err.Error() + } else { + results[fileID] = "success" + } + } + + ctx.JSON(http.StatusOK, gin.H{ + "status": "success", + "results": results, + }) +} \ No newline at end of file diff --git a/backend/controllers/LoginController.go b/backend/controllers/LoginController.go index 51df6b8..2d85b7c 100644 --- a/backend/controllers/LoginController.go +++ b/backend/controllers/LoginController.go @@ -1,113 +1,219 @@ package controllers import ( - "net/http" - "safesplit/config" - "safesplit/models" - "github.com/gin-gonic/gin" - "gorm.io/gorm" + "fmt" + "net/http" + "safesplit/config" + "safesplit/models" + "strings" + "time" + + "github.com/gin-gonic/gin" + "gorm.io/gorm" ) type LoginController struct { - userModel *models.UserModel - billingModel *models.BillingModel + userModel *models.UserModel + billingModel *models.BillingModel + activityLogger *models.ActivityLogModel } type LoginRequest struct { - Email string `json:"email" binding:"required"` - Password string `json:"password" binding:"required"` - TwoFactorCode string `json:"two_factor_code"` + Email string `json:"email" binding:"required"` + Password string `json:"password" binding:"required"` + TwoFactorCode string `json:"two_factor_code"` } type UserResponse struct { - User *models.User `json:"user"` - BillingProfile *models.BillingProfile `json:"billing_profile,omitempty"` + User *models.User `json:"user"` + BillingProfile *models.BillingProfile `json:"billing_profile,omitempty"` } -func NewLoginController(userModel *models.UserModel, billingModel *models.BillingModel) *LoginController { - return &LoginController{ - userModel: userModel, - billingModel: billingModel, - } +func NewLoginController(userModel *models.UserModel, billingModel *models.BillingModel, activityLogger *models.ActivityLogModel) *LoginController { + return &LoginController{ + userModel: userModel, + billingModel: billingModel, + activityLogger: activityLogger, + } } func (c *LoginController) Login(ctx *gin.Context) { - var loginReq LoginRequest - if err := ctx.ShouldBindJSON(&loginReq); err != nil { - ctx.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) - return - } - - user, err := c.userModel.Authenticate(loginReq.Email, loginReq.Password) - if err != nil { - ctx.JSON(http.StatusUnauthorized, gin.H{"error": "Invalid credentials"}) - return - } - - if user.TwoFactorEnabled { - if loginReq.TwoFactorCode == "" { - if err := c.userModel.InitiateEmailTwoFactor(user.ID); err != nil { - ctx.JSON(http.StatusInternalServerError, gin.H{ - "error": "Failed to send 2FA code", - }) - return - } - ctx.JSON(http.StatusAccepted, gin.H{ - "message": "2FA required", - "requires_2fa": true, - "user_id": user.ID, - }) - return - } - - if err := c.userModel.VerifyEmailTwoFactor(user.ID, loginReq.TwoFactorCode); err != nil { - ctx.JSON(http.StatusUnauthorized, gin.H{"error": "Invalid 2FA code"}) - return - } - } - - token, err := config.GenerateToken(user.ID, user.Role) - if err != nil { - ctx.JSON(http.StatusInternalServerError, gin.H{"error": "Error generating token"}) - return - } - - user.Password = "" - billingProfile, err := c.billingModel.GetUserBillingProfile(user.ID) - if err != nil && err != gorm.ErrRecordNotFound { - ctx.JSON(http.StatusInternalServerError, gin.H{"error": "Error fetching billing details"}) - return - } - - ctx.JSON(http.StatusOK, gin.H{ - "token": token, - "data": UserResponse{ - User: user, - BillingProfile: billingProfile, - }, - }) + var loginReq LoginRequest + if err := ctx.ShouldBindJSON(&loginReq); err != nil { + ctx.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + + user, err := c.userModel.Authenticate(loginReq.Email, loginReq.Password) + if err != nil { + var lockedUser *models.User + lockedUser, _ = c.userModel.FindByEmail(loginReq.Email) + + if lockedUser != nil { + c.activityLogger.LogActivity(&models.ActivityLog{ + UserID: lockedUser.ID, + ActivityType: "login", + IPAddress: ctx.ClientIP(), + Status: "failure", + ErrorMessage: err.Error(), + Details: fmt.Sprintf("Failed login attempt from IP: %s", ctx.ClientIP()), + CreatedAt: time.Now(), + }) + } + + if lockedUser != nil && lockedUser.AccountLockedUntil != nil && lockedUser.AccountLockedUntil.After(time.Now()) { + remainingTime := int(lockedUser.AccountLockedUntil.Sub(time.Now()).Minutes()) + ctx.JSON(http.StatusTooManyRequests, gin.H{ + "error": fmt.Sprintf("Account locked for %d minutes", remainingTime), + "status": "locked", + "locked_until": lockedUser.AccountLockedUntil, + "remaining_minutes": remainingTime, + }) + return + } + + if strings.Contains(err.Error(), "attempts remaining") { + parts := strings.Split(err.Error(), " ") + for i, part := range parts { + if part == "remaining" && i > 0 { + attempts := parts[i-1] + ctx.JSON(http.StatusUnauthorized, gin.H{ + "error": err.Error(), + "status": "failed", + "remaining_attempts": attempts, + }) + return + } + } + } + + if lockedUser != nil { + ctx.JSON(http.StatusUnauthorized, gin.H{ + "error": err.Error(), + "status": "failed", + }) + return + } + + ctx.JSON(http.StatusUnauthorized, gin.H{ + "error": "Invalid credentials", + "status": "failed", + }) + return + } + + // Handle 2FA if enabled + if user.TwoFactorEnabled { + if loginReq.TwoFactorCode == "" { + if err := c.userModel.InitiateEmailTwoFactor(user.ID); err != nil { + c.activityLogger.LogActivity(&models.ActivityLog{ + UserID: user.ID, + ActivityType: "login", + IPAddress: ctx.ClientIP(), + Status: "failure", + ErrorMessage: "Failed to send 2FA code", + Details: fmt.Sprintf("2FA initiation failed from IP: %s", ctx.ClientIP()), + CreatedAt: time.Now(), + }) + ctx.JSON(http.StatusInternalServerError, gin.H{ + "error": "Failed to send 2FA code", + }) + return + } + ctx.JSON(http.StatusAccepted, gin.H{ + "message": "2FA required", + "requires_2fa": true, + "user_id": user.ID, + }) + return + } + + if err := c.userModel.VerifyEmailTwoFactor(user.ID, loginReq.TwoFactorCode); err != nil { + c.activityLogger.LogActivity(&models.ActivityLog{ + UserID: user.ID, + ActivityType: "login", + IPAddress: ctx.ClientIP(), + Status: "failure", + ErrorMessage: "Invalid 2FA code", + Details: fmt.Sprintf("Invalid 2FA code from IP: %s", ctx.ClientIP()), + CreatedAt: time.Now(), + }) + ctx.JSON(http.StatusUnauthorized, gin.H{"error": "Invalid 2FA code"}) + return + } + } + + token, err := config.GenerateToken(user.ID, user.Role) + if err != nil { + c.activityLogger.LogActivity(&models.ActivityLog{ + UserID: user.ID, + ActivityType: "login", + IPAddress: ctx.ClientIP(), + Status: "failure", + ErrorMessage: "Error generating token", + Details: fmt.Sprintf("Token generation failed from IP: %s", ctx.ClientIP()), + CreatedAt: time.Now(), + }) + ctx.JSON(http.StatusInternalServerError, gin.H{"error": "Error generating token"}) + return + } + + // Get billing profile + user.Password = "" + billingProfile, err := c.billingModel.GetUserBillingProfile(user.ID) + if err != nil && err != gorm.ErrRecordNotFound { + c.activityLogger.LogActivity(&models.ActivityLog{ + UserID: user.ID, + ActivityType: "login", + IPAddress: ctx.ClientIP(), + Status: "failure", + ErrorMessage: "Error fetching billing details", + Details: fmt.Sprintf("Billing profile fetch failed from IP: %s", ctx.ClientIP()), + CreatedAt: time.Now(), + }) + ctx.JSON(http.StatusInternalServerError, gin.H{"error": "Error fetching billing details"}) + return + } + + // Log successful login + c.activityLogger.LogActivity(&models.ActivityLog{ + UserID: user.ID, + ActivityType: "login", + IPAddress: ctx.ClientIP(), + Status: "success", + Details: fmt.Sprintf("Successful login from IP: %s", ctx.ClientIP()), + CreatedAt: time.Now(), + }) + + ctx.JSON(http.StatusOK, gin.H{ + "token": token, + "data": UserResponse{ + User: user, + BillingProfile: billingProfile, + }, + }) } func (c *LoginController) GetMe(ctx *gin.Context) { - userID := ctx.GetUint("user_id") - user, err := c.userModel.FindByID(userID) - if err != nil { - ctx.JSON(http.StatusNotFound, gin.H{"error": "User not found"}) - return - } - - user.Password = "" - billingProfile, err := c.billingModel.GetUserBillingProfile(userID) - if err != nil && err != gorm.ErrRecordNotFound { - ctx.JSON(http.StatusInternalServerError, gin.H{"error": "Error fetching billing details"}) - return - } - - ctx.JSON(http.StatusOK, gin.H{ - "data": UserResponse{ - User: user, - BillingProfile: billingProfile, - }, - "role": user.Role, - }) -} \ No newline at end of file + userID := ctx.GetUint("user_id") + user, err := c.userModel.FindByID(userID) + if err != nil { + ctx.JSON(http.StatusNotFound, gin.H{"error": "User not found"}) + return + } + + user.Password = "" + billingProfile, err := c.billingModel.GetUserBillingProfile(userID) + if err != nil && err != gorm.ErrRecordNotFound { + ctx.JSON(http.StatusInternalServerError, gin.H{"error": "Error fetching billing details"}) + return + } + + ctx.JSON(http.StatusOK, gin.H{ + "data": UserResponse{ + User: user, + BillingProfile: billingProfile, + }, + "role": user.Role, + }) +} diff --git a/backend/controllers/LogoutController.go b/backend/controllers/LogoutController.go index e04e35f..93262df 100644 --- a/backend/controllers/LogoutController.go +++ b/backend/controllers/LogoutController.go @@ -16,8 +16,6 @@ func NewLogoutController(userModel *models.UserModel) *LogoutController { } func (c *LogoutController) Logout(ctx *gin.Context) { - // Since we're using JWT, just return success - // Frontend will handle token removal ctx.JSON(http.StatusOK, gin.H{ "message": "Successfully logged out", }) diff --git a/backend/controllers/PremiumUser/AdvancedShareFileController.go b/backend/controllers/PremiumUser/AdvancedShareFileController.go index 511c2b1..68849a9 100644 --- a/backend/controllers/PremiumUser/AdvancedShareFileController.go +++ b/backend/controllers/PremiumUser/AdvancedShareFileController.go @@ -6,6 +6,8 @@ import ( "log" "net/http" "net/url" + "os" + "strconv" "strings" "time" @@ -16,14 +18,17 @@ import ( ) type ShareFileController struct { - fileModel *models.FileModel - fileShareModel *models.FileShareModel - keyFragmentModel *models.KeyFragmentModel - encryptionService *services.EncryptionService - activityLogModel *models.ActivityLogModel - rsService *services.ReedSolomonService - userModel *models.UserModel - serverKeyModel *models.ServerMasterKeyModel + fileModel *models.FileModel + fileShareModel *models.FileShareModel + keyFragmentModel *models.KeyFragmentModel + encryptionService *services.EncryptionService + activityLogModel *models.ActivityLogModel + rsService *services.ReedSolomonService + userModel *models.UserModel + serverKeyModel *models.ServerMasterKeyModel + twoFactorService *services.TwoFactorAuthService + emailService *services.SMTPEmailService + compressionService *services.CompressionService } func NewShareFileController( @@ -35,32 +40,46 @@ func NewShareFileController( rsService *services.ReedSolomonService, userModel *models.UserModel, serverKeyModel *models.ServerMasterKeyModel, + twoFactorService *services.TwoFactorAuthService, + emailService *services.SMTPEmailService, + compressionService *services.CompressionService, ) *ShareFileController { return &ShareFileController{ - fileModel: fileModel, - fileShareModel: fileShareModel, - keyFragmentModel: keyFragmentModel, - encryptionService: encryptionService, - activityLogModel: activityLogModel, - rsService: rsService, - userModel: userModel, - serverKeyModel: serverKeyModel, + fileModel: fileModel, + fileShareModel: fileShareModel, + keyFragmentModel: keyFragmentModel, + encryptionService: encryptionService, + activityLogModel: activityLogModel, + rsService: rsService, + userModel: userModel, + serverKeyModel: serverKeyModel, + twoFactorService: twoFactorService, + emailService: emailService, + compressionService: compressionService, } } type CreateShareRequest struct { - FileID uint `json:"file_id" binding:"required"` - Password string `json:"password" binding:"required,min=6"` - ExpiresAt *time.Time `json:"expires_at"` - MaxDownloads *int `json:"max_downloads"` + Password string `json:"password" binding:"required,min=6"` + ExpiresAt *time.Time `json:"expires_at"` + MaxDownloads *int `json:"max_downloads"` + ShareType models.ShareType `json:"share_type" binding:"required"` + Email string `json:"email,omitempty"` } type AccessShareRequest struct { Password string `json:"password" binding:"required"` + Email string `json:"email,omitempty"` +} + +type TwoFactorRequest struct { + Code string `json:"code" binding:"required"` + Password string `json:"password" binding:"required"` } func (c *ShareFileController) CreateShare(ctx *gin.Context) { - log.Printf("Received premium share creation request for file ID: %v", ctx.Param("id")) + log.Printf("Received advanced share creation request for file ID: %v", ctx.Param("id")) + var req CreateShareRequest if err := ctx.ShouldBindJSON(&req); err != nil { ctx.JSON(http.StatusBadRequest, gin.H{ @@ -70,27 +89,31 @@ func (c *ShareFileController) CreateShare(ctx *gin.Context) { return } - // Validate expiry date if provided - if req.ExpiresAt != nil && req.ExpiresAt.Before(time.Now()) { + if req.ShareType == models.RecipientShare && req.Email == "" { ctx.JSON(http.StatusBadRequest, gin.H{ "status": "error", - "error": "Expiry date cannot be in the past", + "error": "Email required for recipient share", }) return } - user, exists := ctx.Get("user") - if !exists { - ctx.JSON(http.StatusUnauthorized, gin.H{ + if req.ExpiresAt != nil && req.ExpiresAt.Before(time.Now()) { + ctx.JSON(http.StatusBadRequest, gin.H{ "status": "error", - "error": "Unauthorized access", + "error": "Expiry date cannot be in the past", }) return } - currentUser := user.(*models.User) - // Get file and verify ownership - file, err := c.fileModel.GetFileForDownload(req.FileID, currentUser.ID) + user := ctx.MustGet("user").(*models.User) + + fileID, err := strconv.ParseUint(ctx.Param("id"), 10, 64) + if err != nil { + ctx.JSON(http.StatusBadRequest, gin.H{"error": "Invalid file ID"}) + return + } + + file, err := c.fileModel.GetFileForDownload(uint(fileID), user.ID) if err != nil { ctx.JSON(http.StatusNotFound, gin.H{ "status": "error", @@ -99,8 +122,7 @@ func (c *ShareFileController) CreateShare(ctx *gin.Context) { return } - // Derive user's KEK - kek, err := services.DeriveKeyEncryptionKey(currentUser.Password, currentUser.MasterKeySalt) + kek, err := services.DeriveKeyEncryptionKey(user.Password, user.MasterKeySalt) if err != nil { log.Printf("Failed to derive KEK: %v", err) ctx.JSON(http.StatusInternalServerError, gin.H{ @@ -110,11 +132,10 @@ func (c *ShareFileController) CreateShare(ctx *gin.Context) { return } - // Decrypt user's master key decryptedMasterKey, err := services.DecryptMasterKey( - currentUser.EncryptedMasterKey, + user.EncryptedMasterKey, kek, - currentUser.MasterKeyNonce, + user.MasterKeyNonce, ) if err != nil { log.Printf("Failed to decrypt master key: %v", err) @@ -125,10 +146,8 @@ func (c *ShareFileController) CreateShare(ctx *gin.Context) { return } - // Use first 32 bytes of decrypted master key userMasterKey := decryptedMasterKey[:32] - // Get fragments fragments, err := c.keyFragmentModel.GetUserFragmentsForFile(file.ID) if err != nil || len(fragments) == 0 { log.Printf("Failed to retrieve key fragments for file %d: %v", file.ID, err) @@ -139,11 +158,7 @@ func (c *ShareFileController) CreateShare(ctx *gin.Context) { return } - log.Printf("Creating premium share for file %d with %d fragments", file.ID, len(fragments)) userFragment := fragments[0] - log.Printf("Selected user fragment with index %d for sharing", userFragment.FragmentIndex) - - // Decrypt the fragment using master key decryptedFragment, err := services.DecryptMasterKey( userFragment.Data, userMasterKey, @@ -158,7 +173,6 @@ func (c *ShareFileController) CreateShare(ctx *gin.Context) { return } - // Encrypt fragment with share password encryptedFragment, err := c.encryptionService.EncryptKeyFragment( decryptedFragment, []byte(req.Password), @@ -172,15 +186,16 @@ func (c *ShareFileController) CreateShare(ctx *gin.Context) { return } - // Create premium share record share := &models.FileShare{ FileID: file.ID, - SharedBy: currentUser.ID, + SharedBy: user.ID, EncryptedKeyFragment: encryptedFragment, FragmentIndex: userFragment.FragmentIndex, ExpiresAt: req.ExpiresAt, MaxDownloads: req.MaxDownloads, IsActive: true, + ShareType: req.ShareType, + Email: req.Email, } if err := c.fileShareModel.CreateFileShare(share, req.Password); err != nil { @@ -192,152 +207,244 @@ func (c *ShareFileController) CreateShare(ctx *gin.Context) { return } - // Log premium share creation - if err := c.activityLogModel.LogActivity(&models.ActivityLog{ - UserID: currentUser.ID, + if req.ShareType == models.RecipientShare { + baseURL := os.Getenv("BASE_URL") + if baseURL == "" { + baseURL = "http://localhost:3000" + } + + shareURL := fmt.Sprintf("%s/protected-share/%s", baseURL, share.ShareLink) + + emailBody := fmt.Sprintf(`Hello, + +You have received a secure file share from %s. + +File: %s +Access Link: %s + +This link requires a password and email verification to access. +Please use the same email address this message was sent to when accessing the file. + +Best regards, +SafeSplit Team`, user.Username, file.OriginalName, shareURL) + + if err := c.emailService.SendEmail( + req.Email, + "Secure File Share Received", + emailBody, + ); err != nil { + log.Printf("Failed to send email: %v", err) + } + } + + c.activityLogModel.LogActivity(&models.ActivityLog{ + UserID: user.ID, ActivityType: "share", FileID: &file.ID, IPAddress: ctx.ClientIP(), Status: "success", - Details: fmt.Sprintf("Premium share created (Expires: %v, Max Downloads: %v)", req.ExpiresAt, req.MaxDownloads), - }); err != nil { - log.Printf("Failed to log share activity: %v", err) + Details: fmt.Sprintf("Created %s share with premium features", req.ShareType), + }) + + baseURL := os.Getenv("BASE_URL") + if baseURL == "" { + baseURL = "http://localhost:3000" } + // Determine share path based on type + sharePath := "/premium/share/" + if req.ShareType == models.RecipientShare { + sharePath = "/protected-share/" + } + + shareURL := fmt.Sprintf("%s%s%s", baseURL, sharePath, share.ShareLink) + ctx.JSON(http.StatusOK, gin.H{ "status": "success", "data": gin.H{ - "share_link": share.ShareLink, + "share_link": shareURL, + "raw_link": share.ShareLink, + "requires_2fa": req.ShareType == models.RecipientShare, }, }) } func (c *ShareFileController) AccessShare(ctx *gin.Context) { shareLink := ctx.Param("shareLink") - log.Printf("Received premium share access request for link: %s", shareLink) + log.Printf("Received share access request for link: %s", shareLink) + + if ctx.Request.Method == "GET" { + share, err := c.fileShareModel.GetShareByLink(shareLink) + if err != nil { + ctx.JSON(http.StatusUnauthorized, gin.H{"error": "Invalid share"}) + return + } + + file, err := c.fileModel.GetFileByID(share.FileID) + if err != nil { + ctx.JSON(http.StatusNotFound, gin.H{"error": "File not found"}) + return + } + + ctx.JSON(http.StatusOK, gin.H{ + "status": "success", + "data": gin.H{ + "requires_password": true, + "requires_2fa": share.ShareType == models.RecipientShare, + "recipient_share": share.ShareType == models.RecipientShare, + "file_name": file.OriginalName, + "file_size": file.Size, + "mime_type": file.MimeType, + "created_at": share.CreatedAt, + "expires_at": share.ExpiresAt, + "download_count": share.DownloadCount, + "max_downloads": share.MaxDownloads, + }, + }) + return + } + var req AccessShareRequest if err := ctx.ShouldBindJSON(&req); err != nil { - ctx.JSON(http.StatusBadRequest, gin.H{ - "status": "error", - "error": "Invalid password", - }) + ctx.JSON(http.StatusBadRequest, gin.H{"error": "Invalid request"}) return } - // Validate share with premium features - share, err := c.fileShareModel.ValidateShare(shareLink, req.Password) + share, err := c.fileShareModel.GetShareByLink(shareLink) if err != nil { - ctx.JSON(http.StatusUnauthorized, gin.H{ - "status": "error", - "error": "Invalid share link or password", - }) + ctx.JSON(http.StatusUnauthorized, gin.H{"error": "Invalid share"}) return } - // Check if share has expired - if share.ExpiresAt != nil && time.Now().After(*share.ExpiresAt) { - ctx.JSON(http.StatusForbidden, gin.H{ - "status": "error", - "error": "Share link has expired", - }) + // Get file info early for use in verification + file, err := c.fileModel.GetFileByID(share.FileID) + if err != nil { + ctx.JSON(http.StatusNotFound, gin.H{"error": "File not found"}) return } - // Check if maximum downloads reached - if share.MaxDownloads != nil && share.DownloadCount >= *share.MaxDownloads { - ctx.JSON(http.StatusForbidden, gin.H{ - "status": "error", - "error": "Maximum number of downloads reached", + if share.ShareType == models.RecipientShare { + // Validate password only - no email needed since we have the share + share, validationErr := c.fileShareModel.ValidateRecipientShare(shareLink, req.Password) + if validationErr != nil { + ctx.JSON(http.StatusUnauthorized, gin.H{ + "status": "error", + "error": "Invalid password"}) + return + } + + // Send verification code to the email associated with the share + if err := c.twoFactorService.SendShareVerificationToken(share.ID, share.Email, file.OriginalName); err != nil { + ctx.JSON(http.StatusInternalServerError, gin.H{ + "status": "error", + "error": "Failed to send verification code"}) + return + } + + ctx.JSON(http.StatusOK, gin.H{ + "status": "success", + "message": "Verification code sent to registered email", + "data": gin.H{ + "share_id": share.ID, + }, }) return } - // Check if share is still active - if !share.IsActive { - ctx.JSON(http.StatusForbidden, gin.H{ - "status": "error", - "error": "Share link is no longer active", - }) + if err := c.validatePremiumShare(share); err != nil { + ctx.JSON(http.StatusForbidden, gin.H{"error": err.Error()}) return } - log.Printf("Processing premium share access for link: %s", shareLink) + c.processFileAccess(ctx, share, req.Password) +} +func (c *ShareFileController) validatePremiumShare(share *models.FileShare) error { + if share.ExpiresAt != nil && time.Now().After(*share.ExpiresAt) { + return fmt.Errorf("share link has expired") + } + + if share.MaxDownloads != nil && share.DownloadCount >= *share.MaxDownloads { + return fmt.Errorf("maximum number of downloads reached") + } - // Get file metadata - file, err := c.fileModel.GetFileByID(share.FileID) - if err != nil { - log.Printf("Failed to get file %d: %v", share.FileID, err) - ctx.JSON(http.StatusNotFound, gin.H{ - "status": "error", - "error": "File not found", - }) + if !share.IsActive { + return fmt.Errorf("share link is no longer active") + } + + return nil +} + +func (c *ShareFileController) Verify2FAAndDownload(ctx *gin.Context) { + shareLink := ctx.Param("shareLink") + var req TwoFactorRequest + if err := ctx.ShouldBindJSON(&req); err != nil { + ctx.JSON(http.StatusBadRequest, gin.H{"error": "Invalid request"}) return } - // Get server fragments - serverFragments, err := c.keyFragmentModel.GetServerFragmentsForFile(share.FileID) + share, err := c.fileShareModel.GetShareByLink(shareLink) if err != nil { - log.Printf("Failed to get server fragments: %v", err) - ctx.JSON(http.StatusInternalServerError, gin.H{ - "status": "error", - "error": "Failed to process file access", - }) + ctx.JSON(http.StatusUnauthorized, gin.H{"error": "Invalid share"}) return } - log.Printf("Retrieved %d server fragments", len(serverFragments)) + if share.ShareType != models.RecipientShare { + ctx.JSON(http.StatusBadRequest, gin.H{"error": "2FA verification only required for recipient shares"}) + return + } - // Verify we have enough fragments - if len(serverFragments)+1 < int(file.Threshold) { - log.Printf("Insufficient fragments: have %d server + 1 shared, need %d", - len(serverFragments), file.Threshold) - ctx.JSON(http.StatusInternalServerError, gin.H{ - "status": "error", - "error": "Insufficient fragments to reconstruct file", - }) + if err := c.twoFactorService.VerifyToken(share.ID, req.Code); err != nil { + ctx.JSON(http.StatusUnauthorized, gin.H{"error": "Invalid 2FA code"}) + return + } + + if err := c.validatePremiumShare(share); err != nil { + ctx.JSON(http.StatusForbidden, gin.H{"error": err.Error()}) + return + } + + c.processFileAccess(ctx, share, req.Password) +} + +func (c *ShareFileController) processFileAccess(ctx *gin.Context, share *models.FileShare, password string) { + file, err := c.fileModel.GetFileByID(share.FileID) + if err != nil { + ctx.JSON(http.StatusNotFound, gin.H{"error": "File not found"}) + return + } + + serverFragments, err := c.keyFragmentModel.GetServerFragmentsForFile(share.FileID) + if err != nil || len(serverFragments)+1 < int(file.Threshold) { + ctx.JSON(http.StatusInternalServerError, gin.H{"error": "Insufficient fragments"}) return } - // Get server key for decrypting server fragments serverKey, err := c.serverKeyModel.GetActive() if err != nil { - log.Printf("Failed to get server key: %v", err) - ctx.JSON(http.StatusInternalServerError, gin.H{ - "status": "error", - "error": "Failed to process decryption", - }) + ctx.JSON(http.StatusInternalServerError, gin.H{"error": "Server key error"}) return } serverKeyData, err := c.serverKeyModel.GetServerKey(serverKey.KeyID) if err != nil { log.Printf("Failed to get server key data: %v", err) - ctx.JSON(http.StatusInternalServerError, gin.H{ - "status": "error", - "error": "Failed to get server key", - }) + ctx.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to get server key"}) return } - // Decrypt shared fragment sharedDecryptedFragment, err := c.encryptionService.DecryptKeyFragment( share.EncryptedKeyFragment, - []byte(req.Password), + []byte(password), ) if err != nil { log.Printf("Failed to decrypt shared fragment: %v", err) - ctx.JSON(http.StatusInternalServerError, gin.H{ - "status": "error", - "error": "Failed to process file decryption", - }) + ctx.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to process file decryption"}) return } - // Prepare key shares array shares := make([]services.KeyShare, file.Threshold) usedIndices := make(map[int]bool) - // Add shared fragment first shares[0] = services.KeyShare{ Index: share.FragmentIndex, Value: hex.EncodeToString(sharedDecryptedFragment), @@ -345,11 +452,9 @@ func (c *ShareFileController) AccessShare(ctx *gin.Context) { usedIndices[share.FragmentIndex] = true log.Printf("Added shared fragment with index %d", share.FragmentIndex) - // Add server fragments sharesAdded := uint(1) for i := 0; i < len(serverFragments) && sharesAdded < file.Threshold; i++ { fragment := serverFragments[i] - if usedIndices[fragment.FragmentIndex] { continue } @@ -371,20 +476,15 @@ func (c *ShareFileController) AccessShare(ctx *gin.Context) { FragmentPath: fragment.FragmentPath, } usedIndices[fragment.FragmentIndex] = true - log.Printf("Added server fragment %d with index %d", i, fragment.FragmentIndex) sharesAdded++ } if sharesAdded < file.Threshold { log.Printf("Failed to get enough unique shares: have %d, need %d", sharesAdded, file.Threshold) - ctx.JSON(http.StatusInternalServerError, gin.H{ - "status": "error", - "error": "Failed to get enough unique shares", - }) + ctx.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to get enough unique shares"}) return } - // Get encrypted file data var encryptedData []byte var retrievalErr error @@ -398,14 +498,10 @@ func (c *ShareFileController) AccessShare(ctx *gin.Context) { if retrievalErr != nil { log.Printf("Failed to retrieve file data: %v", retrievalErr) - ctx.JSON(http.StatusInternalServerError, gin.H{ - "status": "error", - "error": "Failed to read file data", - }) + ctx.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to read file data"}) return } - // Decrypt the file with encryption type decryptedData, err := c.encryptionService.DecryptFileWithType( encryptedData, file.EncryptionIV, @@ -416,38 +512,32 @@ func (c *ShareFileController) AccessShare(ctx *gin.Context) { ) if err != nil { log.Printf("Failed to decrypt file data: %v", err) - ctx.JSON(http.StatusInternalServerError, gin.H{ - "status": "error", - "error": "Failed to decrypt file", - }) + ctx.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to decrypt file"}) return } - // Log premium share access - if err := c.activityLogModel.LogActivity(&models.ActivityLog{ - UserID: share.SharedBy, - ActivityType: "download", - FileID: &file.ID, - IPAddress: ctx.ClientIP(), - Status: "success", - Details: fmt.Sprintf("Premium shared file download using %d fragments", file.Threshold), - }); err != nil { - log.Printf("Failed to log share download activity: %v", err) + if file.IsCompressed { + log.Printf("Decompressing data for file ID: %d", file.ID) + decryptedData, err = c.compressionService.Decompress(decryptedData) + if err != nil { + ctx.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to decompress file"}) + return + } } - // Update premium share status - should be before sending response - log.Printf("Incrementing download count for share ID %d", share.ID) if err := c.fileShareModel.IncrementDownloadCount(share.ID); err != nil { log.Printf("Failed to increment download count: %v", err) - ctx.JSON(http.StatusInternalServerError, gin.H{ - "status": "error", - "error": "Failed to update download count", - }) - return } - log.Printf("Successfully incremented download count for share ID %d", share.ID) - // Only proceed to send file if increment succeeded + c.activityLogModel.LogActivity(&models.ActivityLog{ + UserID: share.SharedBy, + ActivityType: "download", + FileID: &file.ID, + IPAddress: ctx.ClientIP(), + Status: "success", + Details: fmt.Sprintf("Download with %d fragments", file.Threshold), + }) + c.sendFileResponse(ctx, file, decryptedData) } @@ -457,7 +547,6 @@ func (c *ShareFileController) getShardedData(file *models.File) ([]byte, error) return nil, fmt.Errorf("failed to retrieve shards: %w", err) } - // Log shard information for debugging validShards := 0 for i, shard := range fileShards.Shards { if shard != nil { @@ -468,13 +557,11 @@ func (c *ShareFileController) getShardedData(file *models.File) ([]byte, error) } } - // Validate we have enough shards for reconstruction if !c.rsService.ValidateShards(fileShards.Shards, int(file.DataShardCount)) { return nil, fmt.Errorf("insufficient shards for reconstruction: have %d, need %d", validShards, file.DataShardCount) } - // Reconstruct file from shards reconstructed, err := c.rsService.ReconstructFile(fileShards.Shards, int(file.DataShardCount), int(file.ParityShardCount)) if err != nil { @@ -486,18 +573,19 @@ func (c *ShareFileController) getShardedData(file *models.File) ([]byte, error) } func (c *ShareFileController) sendFileResponse(ctx *gin.Context, file *models.File, data []byte) { - sanitizedFilename := strings.ReplaceAll(file.OriginalName, `"`, `\"`) - encodedFilename := url.QueryEscape(sanitizedFilename) - - ctx.Header("Access-Control-Expose-Headers", "Content-Disposition, Content-Type, Content-Length") - ctx.Header("Content-Description", "File Transfer") - ctx.Header("Content-Transfer-Encoding", "binary") - ctx.Header("Content-Disposition", fmt.Sprintf(`attachment; filename="%s"; filename*=UTF-8''%s`, - sanitizedFilename, encodedFilename)) + escapedName := strings.ReplaceAll(file.OriginalName, `"`, `\"`) + utf8Name := url.PathEscape(file.OriginalName) + ctx.Header("Content-Disposition", fmt.Sprintf( + `attachment; filename="%s"; filename*=UTF-8''%s`, + escapedName, + utf8Name, + )) ctx.Header("Content-Type", file.MimeType) ctx.Header("Content-Length", fmt.Sprintf("%d", len(data))) - ctx.Header("X-Original-Filename", url.QueryEscape(file.OriginalName)) - + ctx.Header("X-Original-Filename", escapedName) + ctx.Header("Access-Control-Expose-Headers", "Content-Disposition, Content-Type, Content-Length, X-Original-Filename") + ctx.Header("Content-Description", "File Transfer") + ctx.Header("Content-Transfer-Encoding", "binary") log.Printf("Sending file response: %s (Size: %d bytes)", file.OriginalName, len(data)) ctx.Data(http.StatusOK, file.MimeType, data) } diff --git a/backend/controllers/PremiumUser/FileRecoveryController.go b/backend/controllers/PremiumUser/FileRecoveryController.go index c972053..7ca474f 100644 --- a/backend/controllers/PremiumUser/FileRecoveryController.go +++ b/backend/controllers/PremiumUser/FileRecoveryController.go @@ -20,7 +20,6 @@ func NewFileRecoveryController(fileModel *models.FileModel) *FileRecoveryControl // RecoverFile handles the recovery of a deleted file for premium users func (c *FileRecoveryController) RecoverFile(ctx *gin.Context) { - // Get authenticated user user, exists := ctx.Get("user") if !exists { ctx.JSON(http.StatusUnauthorized, gin.H{ @@ -76,7 +75,6 @@ func (c *FileRecoveryController) RecoverFile(ctx *gin.Context) { // ListRecoverableFiles returns a list of files that can be recovered func (c *FileRecoveryController) ListRecoverableFiles(ctx *gin.Context) { - // Get authenticated user user, exists := ctx.Get("user") if !exists { ctx.JSON(http.StatusUnauthorized, gin.H{ @@ -114,7 +112,6 @@ func (c *FileRecoveryController) ListRecoverableFiles(ctx *gin.Context) { return } - // Format response var response []gin.H for _, file := range files { response = append(response, gin.H{ diff --git a/backend/controllers/PremiumUser/FragmentBackupController.go b/backend/controllers/PremiumUser/FragmentBackupController.go deleted file mode 100644 index 1c1ff8b..0000000 --- a/backend/controllers/PremiumUser/FragmentBackupController.go +++ /dev/null @@ -1,92 +0,0 @@ -package PremiumUser - -import ( - "net/http" - "safesplit/models" - "strconv" - - "github.com/gin-gonic/gin" -) - -type FragmentController struct { - keyFragmentModel *models.KeyFragmentModel - fileModel *models.FileModel -} - -func NewFragmentController( - keyFragmentModel *models.KeyFragmentModel, - fileModel *models.FileModel, -) *FragmentController { - return &FragmentController{ - keyFragmentModel: keyFragmentModel, - fileModel: fileModel, - } -} - -func (c *FragmentController) GetUserFragments(ctx *gin.Context) { - // Get authenticated user - user, exists := ctx.Get("user") - if !exists { - ctx.JSON(http.StatusUnauthorized, gin.H{ - "status": "error", - "error": "Unauthorized access", - }) - return - } - currentUser, ok := user.(*models.User) - if !ok { - ctx.JSON(http.StatusInternalServerError, gin.H{ - "status": "error", - "error": "Invalid user data", - }) - return - } - - // Parse file ID - fileID, err := strconv.ParseUint(ctx.Param("fileId"), 10, 32) - if err != nil { - ctx.JSON(http.StatusBadRequest, gin.H{ - "status": "error", - "error": "Invalid file ID", - }) - return - } - - // Verify file access - file, err := c.fileModel.GetFileForDownload(uint(fileID), currentUser.ID) - if err != nil { - ctx.JSON(http.StatusNotFound, gin.H{ - "status": "error", - "error": "File not found or access denied", - }) - return - } - - // Get user fragments - fragments, err := c.keyFragmentModel.GetUserFragmentsForFile(uint(fileID)) - if err != nil { - ctx.JSON(http.StatusInternalServerError, gin.H{ - "status": "error", - "error": "Failed to retrieve fragments", - }) - return - } - - // Format response - response := make([]gin.H, len(fragments)) - for i, fragment := range fragments { - response[i] = gin.H{ - "index": fragment.FragmentIndex, - "value": string(fragment.Data), - } - } - - ctx.JSON(http.StatusOK, gin.H{ - "status": "success", - "data": gin.H{ - "fileId": fileID, - "fileName": file.OriginalName, - "fragments": response, - }, - }) -} diff --git a/backend/controllers/PremiumUser/UpdateBillingController.go b/backend/controllers/PremiumUser/UpdateBillingController.go new file mode 100644 index 0000000..ad51ddd --- /dev/null +++ b/backend/controllers/PremiumUser/UpdateBillingController.go @@ -0,0 +1,111 @@ +package PremiumUser + +import ( + "net/http" + "safesplit/models" + "github.com/gin-gonic/gin" +) + +type UpdateBillingController struct { + billingModel *models.BillingModel +} + +func NewUpdateBillingController(billingModel *models.BillingModel) *UpdateBillingController { + return &UpdateBillingController{ + billingModel: billingModel, + } +} + +type UpdateBillingRequest struct { + BillingName string `json:"billing_name" binding:"required"` + BillingEmail string `json:"billing_email" binding:"required,email"` + BillingAddress string `json:"billing_address" binding:"required"` + CountryCode string `json:"country_code" binding:"required,len=2"` + DefaultPaymentMethod string `json:"default_payment_method" binding:"required,oneof=credit_card bank_account paypal"` + BillingCycle string `json:"billing_cycle" binding:"required,oneof=monthly yearly"` + Currency string `json:"currency" binding:"required,len=3"` +} + +// UpdateBillingDetails updates the billing profile for a premium user +func (c *UpdateBillingController) UpdateBillingDetails(ctx *gin.Context) { + user, exists := ctx.Get("user") + if !exists { + ctx.JSON(http.StatusUnauthorized, gin.H{"error": "authentication required"}) + return + } + + premiumUser, ok := user.(*models.User) + if !ok { + ctx.JSON(http.StatusInternalServerError, gin.H{"error": "invalid user data"}) + return + } + + var req UpdateBillingRequest + if err := ctx.ShouldBindJSON(&req); err != nil { + ctx.JSON(http.StatusBadRequest, gin.H{"error": "invalid request data"}) + return + } + + // Get existing billing profile or create new one + profile, err := c.billingModel.GetUserBillingProfile(premiumUser.ID) + if err != nil { + // Create new profile if doesn't exist + profile = &models.BillingProfile{ + UserID: premiumUser.ID, + BillingStatus: models.BillingStatusActive, + } + } + + // Update profile with new details + profile.BillingName = req.BillingName + profile.BillingEmail = req.BillingEmail + profile.BillingAddress = req.BillingAddress + profile.CountryCode = req.CountryCode + profile.DefaultPaymentMethod = req.DefaultPaymentMethod + profile.BillingCycle = req.BillingCycle + profile.Currency = req.Currency + + var updateErr error + if profile.ID == 0 { + updateErr = c.billingModel.CreateBillingProfile(profile) + } else { + updateErr = c.billingModel.UpdateBillingProfile(profile) + } + + if updateErr != nil { + ctx.JSON(http.StatusInternalServerError, gin.H{ + "error": "Failed to update billing details", + }) + return + } + + ctx.JSON(http.StatusOK, gin.H{ + "message": "Billing details updated successfully", + "data": profile, + }) +} + +// GetBillingDetails retrieves the current billing profile +func (c *UpdateBillingController) GetBillingDetails(ctx *gin.Context) { + user, exists := ctx.Get("user") + if !exists { + ctx.JSON(http.StatusUnauthorized, gin.H{"error": "authentication required"}) + return + } + + premiumUser, ok := user.(*models.User) + if !ok { + ctx.JSON(http.StatusInternalServerError, gin.H{"error": "invalid user data"}) + return + } + + profile, err := c.billingModel.GetUserBillingProfile(premiumUser.ID) + if err != nil { + ctx.JSON(http.StatusNotFound, gin.H{"error": "billing profile not found"}) + return + } + + ctx.JSON(http.StatusOK, gin.H{ + "data": profile, + }) +} \ No newline at end of file diff --git a/backend/controllers/SuperAdmin/SuperAdminLoginController.go b/backend/controllers/SuperAdmin/SuperAdminLoginController.go index 76ef0d8..51dbbf1 100644 --- a/backend/controllers/SuperAdmin/SuperAdminLoginController.go +++ b/backend/controllers/SuperAdmin/SuperAdminLoginController.go @@ -12,6 +12,12 @@ type LoginController struct { userModel *models.UserModel } +type LoginRequest struct { + Email string `json:"email" binding:"required"` + Password string `json:"password" binding:"required"` + TwoFactorCode string `json:"two_factor_code"` +} + func NewLoginController(userModel *models.UserModel) *LoginController { return &LoginController{ userModel: userModel, @@ -19,24 +25,44 @@ func NewLoginController(userModel *models.UserModel) *LoginController { } func (c *LoginController) Login(ctx *gin.Context) { - var loginReq struct { - Email string `json:"email" binding:"required"` - Password string `json:"password" binding:"required"` - } - + var loginReq LoginRequest if err := ctx.ShouldBindJSON(&loginReq); err != nil { ctx.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) return } - // Authenticate super admin + // First authenticate super admin credentials user, err := c.userModel.AuthenticateSuperAdmin(loginReq.Email, loginReq.Password) if err != nil { ctx.JSON(http.StatusUnauthorized, gin.H{"error": "Invalid super admin credentials"}) return } - // Generate token + // Always require 2FA for super admin + if loginReq.TwoFactorCode == "" { + // Initiate 2FA if code not provided + if err := c.userModel.InitiateEmailTwoFactor(user.ID); err != nil { + ctx.JSON(http.StatusInternalServerError, gin.H{ + "error": "Failed to send 2FA code", + }) + return + } + + ctx.JSON(http.StatusAccepted, gin.H{ + "message": "2FA required", + "requires_2fa": true, + "user_id": user.ID, + }) + return + } + + // Verify 2FA code + if err := c.userModel.VerifyEmailTwoFactor(user.ID, loginReq.TwoFactorCode); err != nil { + ctx.JSON(http.StatusUnauthorized, gin.H{"error": "Invalid 2FA code"}) + return + } + + // Generate token after successful 2FA token, err := config.GenerateToken(user.ID, user.Role) if err != nil { ctx.JSON(http.StatusInternalServerError, gin.H{"error": "Error generating token"}) diff --git a/backend/controllers/SuperAdmin/ViewSysAdminController.go b/backend/controllers/SuperAdmin/ViewSysAdminController.go index 0064a99..720a992 100644 --- a/backend/controllers/SuperAdmin/ViewSysAdminController.go +++ b/backend/controllers/SuperAdmin/ViewSysAdminController.go @@ -18,14 +18,12 @@ func NewViewSysAdminController(userModel *models.UserModel) *ViewSysAdminControl } func (c *ViewSysAdminController) ListSysAdmins(ctx *gin.Context) { - // Get the authenticated super admin from context superAdmin, exists := ctx.Get("user") if !exists { ctx.JSON(http.StatusUnauthorized, gin.H{"error": "authentication required"}) return } - // Get list of sys admins sysAdmins, err := c.userModel.GetSysAdmins(superAdmin.(*models.User)) if err != nil { ctx.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) diff --git a/backend/controllers/SysAdmin/ViewBillingRecordsController.go b/backend/controllers/SysAdmin/ViewBillingRecordsController.go new file mode 100644 index 0000000..d69f390 --- /dev/null +++ b/backend/controllers/SysAdmin/ViewBillingRecordsController.go @@ -0,0 +1,129 @@ +package SysAdmin + +import ( + "net/http" + "safesplit/models" + "github.com/gin-gonic/gin" +) + +type ViewBillingRecordsController struct { + billingModel *models.BillingModel +} + +func NewViewBillingRecordsController(billingModel *models.BillingModel) *ViewBillingRecordsController { + return &ViewBillingRecordsController{ + billingModel: billingModel, + } +} + +type ListBillingRecordsRequest struct { + Page int `form:"page,default=1" binding:"min=1"` + PageSize int `form:"page_size,default=10" binding:"min=1,max=100"` + Status string `form:"status"` + Cycle string `form:"cycle"` +} + +// GetBillingStats returns subscription statistics +func (c *ViewBillingRecordsController) GetBillingStats(ctx *gin.Context) { + admin, exists := ctx.Get("user") + if !exists { + ctx.JSON(http.StatusUnauthorized, gin.H{"error": "authentication required"}) + return + } + + sysAdmin, ok := admin.(*models.User) + if !ok || (!sysAdmin.IsSysAdmin() && !sysAdmin.IsSuperAdmin()) { + ctx.JSON(http.StatusForbidden, gin.H{"error": "unauthorized access"}) + return + } + + stats, err := c.billingModel.GetSubscriptionStats() + if err != nil { + ctx.JSON(http.StatusInternalServerError, gin.H{ + "error": "Failed to fetch billing statistics", + }) + return + } + + ctx.JSON(http.StatusOK, gin.H{ + "status": "success", + "data": stats, + }) +} + +func (c *ViewBillingRecordsController) GetAllBillingRecords(ctx *gin.Context) { + admin, exists := ctx.Get("user") + if !exists { + ctx.JSON(http.StatusUnauthorized, gin.H{"error": "authentication required"}) + return + } + + sysAdmin, ok := admin.(*models.User) + if !ok || (!sysAdmin.IsSysAdmin() && !sysAdmin.IsSuperAdmin()) { + ctx.JSON(http.StatusForbidden, gin.H{"error": "unauthorized access"}) + return + } + + var req ListBillingRecordsRequest + if err := ctx.ShouldBindQuery(&req); err != nil { + ctx.JSON(http.StatusBadRequest, gin.H{ + "error": "Invalid request parameters", + }) + return + } + + profiles, totalCount, err := c.billingModel.GetAllBillingRecords( + req.Status, + req.Cycle, + req.Page, + req.PageSize, + ) + if err != nil { + ctx.JSON(http.StatusInternalServerError, gin.H{ + "error": "Failed to fetch billing records", + }) + return + } + + ctx.JSON(http.StatusOK, gin.H{ + "status": "success", + "data": gin.H{ + "records": profiles, + "meta": gin.H{ + "total": totalCount, + "page": req.Page, + "page_size": req.PageSize, + "total_pages": (totalCount + int64(req.PageSize) - 1) / int64(req.PageSize), + }, + }, + }) +} + +// GetExpiringSubscriptions retrieves subscriptions that will expire soon +func (c *ViewBillingRecordsController) GetExpiringSubscriptions(ctx *gin.Context) { + admin, exists := ctx.Get("user") + if !exists { + ctx.JSON(http.StatusUnauthorized, gin.H{"error": "authentication required"}) + return + } + + sysAdmin, ok := admin.(*models.User) + if !ok || (!sysAdmin.IsSysAdmin() && !sysAdmin.IsSuperAdmin()) { + ctx.JSON(http.StatusForbidden, gin.H{"error": "unauthorized access"}) + return + } + + // Get subscriptions expiring in next 7 days + expiringProfiles, err := c.billingModel.GetExpiringSubscriptions(7) + if err != nil { + ctx.JSON(http.StatusInternalServerError, gin.H{ + "error": "Failed to fetch expiring subscriptions", + }) + return + } + + ctx.JSON(http.StatusOK, gin.H{ + "status": "success", + "data": expiringProfiles, + }) +} \ No newline at end of file diff --git a/backend/jobs/Subcription.go b/backend/jobs/Subcription.go deleted file mode 100644 index 3d7d58f..0000000 --- a/backend/jobs/Subcription.go +++ /dev/null @@ -1,45 +0,0 @@ -package jobs - -import ( - "time" - - "gorm.io/gorm" -) - -type SubscriptionHandler struct { - db *gorm.DB -} - -func NewSubscriptionHandler(db *gorm.DB) *SubscriptionHandler { - return &SubscriptionHandler{db: db} -} - -func (h *SubscriptionHandler) ProcessExpiredSubscriptions() error { - return h.db.Transaction(func(tx *gorm.DB) error { - now := time.Now() - - // Update all expired subscriptions regardless of storage - result := tx.Exec(` - UPDATE users u - INNER JOIN billing_profiles bp ON u.id = bp.user_id - SET u.subscription_status = 'free', - u.role = 'end_user', - bp.billing_status = 'failed', - bp.next_billing_date = NULL - WHERE bp.billing_status = 'cancelled' - AND bp.next_billing_date < ? - `, now) - - return result.Error - }) -} - -func StartSubscriptionScheduler(handler *SubscriptionHandler) { - ticker := time.NewTicker(24 * time.Hour) - go func() { - for range ticker.C { - if err := handler.ProcessExpiredSubscriptions(); err != nil { - } - } - }() -} diff --git a/backend/jobs/account_management.go b/backend/jobs/account_management.go new file mode 100644 index 0000000..0642175 --- /dev/null +++ b/backend/jobs/account_management.go @@ -0,0 +1,216 @@ +package jobs + +import ( + "log" + "time" + "gorm.io/gorm" +) + +// Constants for job scheduling and thresholds +const ( + InactivityThreshold = 90 * 24 * time.Hour // 90 days + DeletionThreshold = 180 * 24 * time.Hour // 180 days + + AccountProcessingInterval = 1 * time.Hour + SubscriptionInterval = 24 * time.Hour +) + +type User struct { + ID uint + IsActive bool + LastLogin *time.Time + UpdatedAt time.Time + Role string + AccountLockedUntil *time.Time + SubscriptionStatus string + StorageQuota int64 +} + +type JobManager struct { + db *gorm.DB + accountManager *AccountManager + subHandler *SubscriptionHandler +} + +func NewJobManager(db *gorm.DB) *JobManager { + return &JobManager{ + db: db, + accountManager: NewAccountManager(db), + subHandler: NewSubscriptionHandler(db), + } +} + +func (m *JobManager) StartAllJobs() { + m.StartAccountManagementJob() + m.StartSubscriptionJob() + log.Println("All scheduled jobs started") +} + +func (m *JobManager) StartAccountManagementJob() { + ticker := time.NewTicker(AccountProcessingInterval) + go func() { + for range ticker.C { + if err := m.accountManager.ProcessAccounts(); err != nil { + log.Printf("Error in account management job: %v", err) + } + } + }() + log.Println("Account management job started") +} + +func (m *JobManager) StartSubscriptionJob() { + ticker := time.NewTicker(SubscriptionInterval) + go func() { + for range ticker.C { + if err := m.subHandler.ProcessExpiredSubscriptions(); err != nil { + log.Printf("Error in subscription processing job: %v", err) + } + } + }() + log.Println("Subscription processing job started") +} + +type AccountManager struct { + db *gorm.DB +} + +func NewAccountManager(db *gorm.DB) *AccountManager { + return &AccountManager{db: db} +} + +func (m *AccountManager) ProcessAccounts() error { + if err := m.unlockExpiredAccounts(); err != nil { + log.Printf("Error unlocking accounts: %v", err) + } + + if err := m.deactivateInactiveAccounts(); err != nil { + log.Printf("Error deactivating accounts: %v", err) + } + + if err := m.deleteInactiveAccounts(); err != nil { + log.Printf("Error deleting accounts: %v", err) + } + + return nil +} + +func (m *AccountManager) unlockExpiredAccounts() error { + result := m.db.Exec(` + UPDATE users + SET account_locked_until = NULL, + failed_login_attempts = 0 + WHERE account_locked_until < ? + AND account_locked_until IS NOT NULL + AND is_active = true`, + time.Now(), + ) + + if result.Error != nil { + return result.Error + } + + if result.RowsAffected > 0 { + log.Printf("Unlocked %d accounts", result.RowsAffected) + } + return nil +} + +func (m *AccountManager) deactivateInactiveAccounts() error { + result := m.db.Exec(` + UPDATE users + SET is_active = false, + subscription_status = CASE + WHEN subscription_status = 'premium' THEN 'cancelled' + ELSE subscription_status + END, + storage_quota = CASE + WHEN subscription_status = 'premium' THEN 5368709120 -- 5GB + ELSE storage_quota + END + WHERE last_login < ? + AND is_active = true + AND role NOT IN ('sys_admin', 'super_admin')`, + time.Now().Add(-InactivityThreshold), + ) + + if result.Error != nil { + return result.Error + } + + if result.RowsAffected > 0 { + log.Printf("Deactivated %d inactive accounts", result.RowsAffected) + } + return nil +} + +func (m *AccountManager) deleteInactiveAccounts() error { + return m.db.Transaction(func(tx *gorm.DB) error { + // First, get the IDs of accounts to be deleted + var userIDs []uint + if err := tx.Model(&User{}). + Where("is_active = ? AND updated_at < ? AND role NOT IN ?", + false, + time.Now().Add(-DeletionThreshold), + []string{"sys_admin", "super_admin"}). + Pluck("id", &userIDs).Error; err != nil { + return err + } + + if len(userIDs) == 0 { + return nil + } + + deleteQueries := []string{ + "DELETE FROM password_history WHERE user_id IN (?)", + "DELETE FROM activity_logs WHERE user_id IN (?)", + "DELETE FROM billing_profiles WHERE user_id IN (?)", + "DELETE FROM key_fragments WHERE user_id IN (?)", + "DELETE FROM user_files WHERE user_id IN (?)", + "DELETE FROM users WHERE id IN (?)", + } + + for _, query := range deleteQueries { + if err := tx.Exec(query, userIDs).Error; err != nil { + return err + } + } + + log.Printf("Permanently deleted %d inactive accounts", len(userIDs)) + return nil + }) +} + +type SubscriptionHandler struct { + db *gorm.DB +} + +func NewSubscriptionHandler(db *gorm.DB) *SubscriptionHandler { + return &SubscriptionHandler{db: db} +} + +func (h *SubscriptionHandler) ProcessExpiredSubscriptions() error { + return h.db.Transaction(func(tx *gorm.DB) error { + result := tx.Exec(` + UPDATE users u + INNER JOIN billing_profiles bp ON u.id = bp.user_id + SET u.subscription_status = 'free', + u.role = 'end_user', + u.storage_quota = 5368709120, -- 5GB + bp.billing_status = 'failed', + bp.next_billing_date = NULL + WHERE bp.billing_status = 'cancelled' + AND bp.next_billing_date < ? + AND u.subscription_status = 'premium'`, + time.Now(), + ) + + if result.Error != nil { + return result.Error + } + + if result.RowsAffected > 0 { + log.Printf("Processed %d expired subscriptions", result.RowsAffected) + } + return nil + }) +} \ No newline at end of file diff --git a/backend/main.go b/backend/main.go index 8ab61eb..0d1acd1 100644 --- a/backend/main.go +++ b/backend/main.go @@ -11,6 +11,7 @@ import ( "safesplit/services" "strconv" "time" + "github.com/joho/godotenv" "github.com/gin-contrib/cors" @@ -56,8 +57,8 @@ func main() { twoFactorService := services.NewTwoFactorAuthService(emailService) // Initialize subscription handler and scheduler - subscriptionHandler := jobs.NewSubscriptionHandler(db) - jobs.StartSubscriptionScheduler(subscriptionHandler) + jobManager := jobs.NewJobManager(db) + jobManager.StartAllJobs() // Initialize distributed storage service storageService, err := services.NewDistributedStorageService(baseStoragePath, nodeCount) @@ -80,7 +81,6 @@ func main() { folderModel := models.NewFolderModel(db) fileShareModel := models.NewFileShareModel(db) keyFragmentModel := models.NewKeyFragmentModel(db, storageService) - keyRotationModel := models.NewKeyRotationModel(db) feedbackModel := models.NewFeedbackModel(db) // Initialize core services @@ -110,7 +110,7 @@ func main() { ) // Start cleanup scheduler for deleted files go func() { - ticker := time.NewTicker(24 * time.Hour) + ticker := time.NewTicker(24 * time.Hour) defer ticker.Stop() log.Println("Starting file cleanup scheduler...") @@ -137,7 +137,6 @@ func main() { folderModel, fileShareModel, keyFragmentModel, - keyRotationModel, serverMasterKeyModel, feedbackModel, encryptionService, @@ -145,6 +144,7 @@ func main() { compressionService, rsService, twoFactorService, + emailService, ) // Set up the Gin router with default middleware diff --git a/backend/middleware/auth_middleware.go b/backend/middleware/auth_middleware.go index 7c39730..44361ac 100644 --- a/backend/middleware/auth_middleware.go +++ b/backend/middleware/auth_middleware.go @@ -15,7 +15,7 @@ func AuthMiddleware(userModel *models.UserModel) gin.HandlerFunc { return func(c *gin.Context) { fmt.Printf("Starting auth middleware\n") authHeader := c.GetHeader("Authorization") - fmt.Println("Authorization header:", authHeader) // Debug log + fmt.Println("Authorization header:", authHeader) if authHeader == "" { fmt.Println("Missing Authorization header") @@ -33,7 +33,7 @@ func AuthMiddleware(userModel *models.UserModel) gin.HandlerFunc { } tokenStr := bearerToken[1] - fmt.Println("Token:", tokenStr) // Debug log + fmt.Println("Token:", tokenStr) token, err := jwt.Parse(tokenStr, func(token *jwt.Token) (interface{}, error) { if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok { @@ -72,7 +72,7 @@ func AuthMiddleware(userModel *models.UserModel) gin.HandlerFunc { return } - fmt.Println("User ID from token claims:", uint(userID)) // Debug log + fmt.Println("User ID from token claims:", uint(userID)) user, err := userModel.FindByID(uint(userID)) if err != nil { @@ -89,7 +89,6 @@ func AuthMiddleware(userModel *models.UserModel) gin.HandlerFunc { return } - // Set both "user" and "user_id" in context fmt.Printf("Setting userID in context: %d\n", user.ID) c.Set("user", user) c.Set("user_id", user.ID) diff --git a/backend/models/activity_log.go b/backend/models/activity_log.go index b74e17f..3210a1b 100644 --- a/backend/models/activity_log.go +++ b/backend/models/activity_log.go @@ -48,10 +48,8 @@ func (m *ActivityLogModel) GetSystemLogs(filters map[string]interface{}, page, p query = query.Where("user_id = ?", userID) } - // Get total count query.Count(&total) - // Apply pagination offset := (page - 1) * pageSize err := query. Order("created_at DESC"). diff --git a/backend/models/billing.go b/backend/models/billing.go index f66e765..4b10eaf 100644 --- a/backend/models/billing.go +++ b/backend/models/billing.go @@ -78,18 +78,15 @@ func NewBillingModel(db *gorm.DB, userModel *UserModel) *BillingModel { // CreateBillingProfile creates a new billing profile for a user func (m *BillingModel) CreateBillingProfile(profile *BillingProfile) error { return m.db.Transaction(func(tx *gorm.DB) error { - // Check if user already has a billing profile var existingProfile BillingProfile if err := tx.Where("user_id = ?", profile.UserID).First(&existingProfile).Error; err == nil { return errors.New("user already has a billing profile") } - // Generate customer ID only if not provided if profile.CustomerID == "" { profile.CustomerID = fmt.Sprintf("CUST_%d_%s", profile.UserID, time.Now().Format("20060102")) } - // Create the billing profile if err := tx.Create(profile).Error; err != nil { return fmt.Errorf("failed to create billing profile: %v", err) } @@ -98,7 +95,6 @@ func (m *BillingModel) CreateBillingProfile(profile *BillingProfile) error { }) } -// updateBillingProfile updates or creates a billing profile func (m *BillingModel) UpdateBillingProfile(profile *BillingProfile) error { return m.db.Transaction(func(tx *gorm.DB) error { updates := map[string]interface{}{ @@ -127,7 +123,6 @@ func (m *BillingModel) UpdateBillingProfile(profile *BillingProfile) error { return nil }) } -// GetUserBillingProfile retrieves a user's billing profile func (m *BillingModel) GetUserBillingProfile(userID uint) (*BillingProfile, error) { var profile BillingProfile if err := m.db.Where("user_id = ?", userID).First(&profile).Error; err != nil { @@ -136,7 +131,6 @@ func (m *BillingModel) GetUserBillingProfile(userID uint) (*BillingProfile, erro return &profile, nil } -// GetUserWithBilling retrieves user and their billing information func (m *BillingModel) GetUserWithBilling(userID uint) (*UserBillingInfo, error) { var info UserBillingInfo @@ -154,8 +148,35 @@ func (m *BillingModel) GetUserWithBilling(userID uint) (*UserBillingInfo, error) return &info, nil } +func (m *BillingModel) GetAllBillingRecords(status, cycle string, page, pageSize int) ([]BillingProfile, int64, error) { + var profiles []BillingProfile + var totalCount int64 + + query := m.db.Model(&BillingProfile{}) + + if status != "" { + query = query.Where("billing_status = ?", status) + } + if cycle != "" { + query = query.Where("billing_cycle = ?", cycle) + } + + if err := query.Count(&totalCount).Error; err != nil { + return nil, 0, err + } + + offset := (page - 1) * pageSize + err := query.Offset(offset).Limit(pageSize). + Preload("User"). + Find(&profiles).Error + + if err != nil { + return nil, 0, err + } + + return profiles, totalCount, nil +} -// UpdateSubscriptionStatus updates subscription and billing status func (m *BillingModel) UpdateSubscriptionStatus(userID uint, status string) error { return m.db.Transaction(func(tx *gorm.DB) error { var user User @@ -190,7 +211,6 @@ func (m *BillingModel) UpdateSubscriptionStatus(userID uint, status string) erro }) } -// CancelSubscription cancels user's subscription func (m *BillingModel) CancelSubscription(userID uint) error { return m.db.Transaction(func(tx *gorm.DB) error { var user User @@ -198,12 +218,10 @@ func (m *BillingModel) CancelSubscription(userID uint) error { return err } - // Check storage quota before scheduling downgrade if user.StorageUsed > DefaultStorageQuota { return ErrStorageExceedsQuota } - // Keep subscription active until next billing date var profile BillingProfile if err := tx.Where("user_id = ?", userID).First(&profile).Error; err != nil { return err @@ -216,7 +234,6 @@ func (m *BillingModel) CancelSubscription(userID uint) error { }) } -// GetSubscriptionStats gets billing statistics func (m *BillingModel) GetSubscriptionStats() (map[string]interface{}, error) { var stats map[string]interface{} @@ -232,7 +249,6 @@ func (m *BillingModel) GetSubscriptionStats() (map[string]interface{}, error) { return stats, err } -// GetExpiringSubscriptions gets subscriptions expiring soon func (m *BillingModel) GetExpiringSubscriptions(days int) ([]BillingProfile, error) { var profiles []BillingProfile expiryDate := time.Now().AddDate(0, 0, days) diff --git a/backend/models/feedback.go b/backend/models/feedback.go index 9ece2c8..f1c597f 100644 --- a/backend/models/feedback.go +++ b/backend/models/feedback.go @@ -1,8 +1,9 @@ package models import ( + "fmt" "time" - "fmt" + "gorm.io/gorm" ) @@ -13,41 +14,36 @@ const ( FeedbackTypeFeedback FeedbackType = "feedback" FeedbackTypeSuspiciousActivity FeedbackType = "suspicious_activity" - FeedbackStatusPending FeedbackStatus = "pending" - FeedbackStatusInReview FeedbackStatus = "in_review" - FeedbackStatusResolved FeedbackStatus = "resolved" + FeedbackStatusPending FeedbackStatus = "pending" + FeedbackStatusInReview FeedbackStatus = "in_review" + FeedbackStatusResolved FeedbackStatus = "resolved" ) -// Feedback represents the feedback table in the database type Feedback struct { ID uint `json:"id" gorm:"primaryKey"` - UserID uint `json:"user_id"` - Type FeedbackType `json:"type" gorm:"type:enum('feedback','suspicious_activity')"` - Subject string `json:"subject" gorm:"size:255;not null"` - Message string `json:"message" gorm:"type:text;not null"` - Details string `json:"details" gorm:"type:text"` + UserID uint `json:"user_id"` + Type FeedbackType `json:"type" gorm:"type:enum('feedback','suspicious_activity')"` + Subject string `json:"subject" gorm:"size:255;not null"` + Message string `json:"message" gorm:"type:text;not null"` + Details string `json:"details" gorm:"type:text"` Status FeedbackStatus `json:"status" gorm:"type:enum('pending','in_review','resolved');default:pending"` - CreatedAt time.Time `json:"created_at" gorm:"default:CURRENT_TIMESTAMP"` - UpdatedAt time.Time `json:"updated_at" gorm:"default:CURRENT_TIMESTAMP;ON UPDATE CURRENT_TIMESTAMP"` - User User `json:"user" gorm:"foreignKey:UserID"` + CreatedAt time.Time `json:"created_at" gorm:"default:CURRENT_TIMESTAMP"` + UpdatedAt time.Time `json:"updated_at" gorm:"default:CURRENT_TIMESTAMP;ON UPDATE CURRENT_TIMESTAMP"` + User User `json:"user" gorm:"foreignKey:UserID"` } -// FeedbackModel handles database operations for feedback type FeedbackModel struct { db *gorm.DB } -// NewFeedbackModel creates a new FeedbackModel instance func NewFeedbackModel(db *gorm.DB) *FeedbackModel { return &FeedbackModel{db: db} } -// Create adds a new feedback entry func (m *FeedbackModel) Create(feedback *Feedback) error { return m.db.Create(feedback).Error } -// GetByID retrieves a feedback entry by its ID func (m *FeedbackModel) GetByID(id uint) (*Feedback, error) { var feedback Feedback if err := m.db.Preload("User").First(&feedback, id).Error; err != nil { @@ -56,7 +52,6 @@ func (m *FeedbackModel) GetByID(id uint) (*Feedback, error) { return &feedback, nil } -// GetAllByUser retrieves all feedback entries for a specific user func (m *FeedbackModel) GetAllByUser(userID uint) ([]Feedback, error) { var feedbacks []Feedback if err := m.db.Where("user_id = ?", userID).Find(&feedbacks).Error; err != nil { @@ -65,7 +60,6 @@ func (m *FeedbackModel) GetAllByUser(userID uint) ([]Feedback, error) { return feedbacks, nil } -// GetAllByUserAndType retrieves all feedback entries for a specific user and type func (m *FeedbackModel) GetAllByUserAndType(userID uint, feedbackType FeedbackType) ([]Feedback, error) { var feedbacks []Feedback err := m.db.Where("user_id = ? AND type = ?", userID, feedbackType). @@ -77,51 +71,44 @@ func (m *FeedbackModel) GetAllByUserAndType(userID uint, feedbackType FeedbackTy return feedbacks, nil } -// GetAll retrieves all feedback entries with optional filters func (m *FeedbackModel) GetAll(filters map[string]interface{}, page, pageSize int) ([]Feedback, int64, error) { - var feedbacks []Feedback - var total int64 - - query := m.db.Model(&Feedback{}).Preload("User") - - // Apply filters - if feedbackType, ok := filters["type"].(FeedbackType); ok { - query = query.Where("type = ?", feedbackType) - } - if status, ok := filters["status"].(string); ok { - query = query.Where("status = ?", status) - } - if userID, ok := filters["user_id"].(uint); ok { - query = query.Where("user_id = ?", userID) - } - - // Get total count - query.Count(&total) - - // Apply pagination - offset := (page - 1) * pageSize - err := query. - Order("created_at DESC"). - Offset(offset). - Limit(pageSize). - Find(&feedbacks).Error - - return feedbacks, total, err + var feedbacks []Feedback + var total int64 + + query := m.db.Model(&Feedback{}).Preload("User") + + if feedbackType, ok := filters["type"].(FeedbackType); ok { + query = query.Where("type = ?", feedbackType) + } + if status, ok := filters["status"].(string); ok { + query = query.Where("status = ?", status) + } + if userID, ok := filters["user_id"].(uint); ok { + query = query.Where("user_id = ?", userID) + } + + query.Count(&total) + + offset := (page - 1) * pageSize + err := query. + Order("created_at DESC"). + Offset(offset). + Limit(pageSize). + Find(&feedbacks).Error + + return feedbacks, total, err } -// UpdateStatus updates the status of a feedback entry func (m *FeedbackModel) UpdateStatus(id uint, status FeedbackStatus) error { return m.db.Model(&Feedback{}). Where("id = ?", id). Update("status", status).Error } -// Delete removes a feedback entry func (m *FeedbackModel) Delete(id uint) error { return m.db.Delete(&Feedback{}, id).Error } -// GetByStatus retrieves all feedback entries with a specific status func (m *FeedbackModel) GetByStatus(status FeedbackStatus) ([]Feedback, error) { var feedbacks []Feedback if err := m.db.Where("status = ?", status).Find(&feedbacks).Error; err != nil { @@ -130,7 +117,6 @@ func (m *FeedbackModel) GetByStatus(status FeedbackStatus) ([]Feedback, error) { return feedbacks, nil } -// GetByType retrieves all feedback entries of a specific type func (m *FeedbackModel) GetByType(feedbackType FeedbackType) ([]Feedback, error) { var feedbacks []Feedback if err := m.db.Where("type = ?", feedbackType).Find(&feedbacks).Error; err != nil { @@ -139,7 +125,6 @@ func (m *FeedbackModel) GetByType(feedbackType FeedbackType) ([]Feedback, error) return feedbacks, nil } -// GetPendingCount returns the count of pending feedback entries func (m *FeedbackModel) GetPendingCount() (int64, error) { var count int64 err := m.db.Model(&Feedback{}). @@ -148,7 +133,6 @@ func (m *FeedbackModel) GetPendingCount() (int64, error) { return count, err } -// GetDateRangeCount returns the count of feedback entries within a date range func (m *FeedbackModel) GetDateRangeCount(startDate, endDate time.Time) (int64, error) { var count int64 err := m.db.Model(&Feedback{}). @@ -157,7 +141,6 @@ func (m *FeedbackModel) GetDateRangeCount(startDate, endDate time.Time) (int64, return count, err } -// UpdateStatusWithComment updates both the status and adds a comment to the feedback func (m *FeedbackModel) UpdateStatusWithComment(id uint, status FeedbackStatus, comment string) error { return m.db.Transaction(func(tx *gorm.DB) error { feedback, err := m.GetByID(id) @@ -165,24 +148,21 @@ func (m *FeedbackModel) UpdateStatusWithComment(id uint, status FeedbackStatus, return fmt.Errorf("feedback not found: %w", err) } - // Update status feedback.Status = status - // Append comment to details with timestamp timestamp := time.Now().Format(time.RFC3339) newDetails := fmt.Sprintf("%s\n[%s] Status changed to %s: %s", - feedback.Details, // Keep existing details + feedback.Details, timestamp, status, comment, ) feedback.Details = newDetails - // Save changes if err := tx.Save(feedback).Error; err != nil { return fmt.Errorf("failed to update feedback: %w", err) } return nil }) -} \ No newline at end of file +} diff --git a/backend/models/file.go b/backend/models/file.go index c77b882..869eba3 100644 --- a/backend/models/file.go +++ b/backend/models/file.go @@ -123,28 +123,27 @@ func (m *FileModel) CreateFileWithShards( serverKeyModel *ServerMasterKeyModel, ) error { return withTransactionRetry(m.db, 3, func(tx *gorm.DB) error { - // 1. Create file record + // 1. Update user storage first to lock the user row + if err := m.UpdateUserStorage(tx, file.UserID, file.Size); err != nil { + return fmt.Errorf("failed to update storage usage: %w", err) + } + + // 2. Create file record if err := m.CreateFile(tx, file); err != nil { return fmt.Errorf("failed to create file record: %w", err) } - // 2. Store shards + // 3. Store shards if err := m.rsService.StoreShards(file.ID, &services.FileShards{Shards: shards}); err != nil { return fmt.Errorf("failed to store shards: %w", err) } - // 3. Save key fragments + // 4. Save key fragments if err := keyFragmentModel.SaveKeyFragments(tx, file.ID, shares, file.UserID, serverKeyModel); err != nil { m.rsService.DeleteShards(file.ID) // clean up return fmt.Errorf("failed to save key fragments: %w", err) } - // 4. Update user storage - if err := m.UpdateUserStorage(tx, file.UserID, file.Size); err != nil { - m.rsService.DeleteShards(file.ID) - return fmt.Errorf("failed to update storage usage: %w", err) - } - // 5. Log activity activity := &ActivityLog{ UserID: file.UserID, @@ -159,7 +158,6 @@ func (m *FileModel) CreateFileWithShards( return fmt.Errorf("failed to log activity: %w", err) } - // If everything succeeded, return nil return nil }) } @@ -382,11 +380,9 @@ func (m *FileModel) DeleteFile(fileID, userID uint, ipAddress string) error { } // Keep shards for potential recovery - // We'll only delete physical files for non-sharded files if !file.IsSharded && file.FilePath != "" { if err := os.Remove(file.FilePath); err != nil && !os.IsNotExist(err) { log.Printf("Failed to delete file content - Path: %s, Error: %v", file.FilePath, err) - // Don't rollback here as the file might have been already moved/deleted log.Printf("Continuing deletion process despite file removal error") } } @@ -460,6 +456,44 @@ func (m *FileModel) ArchiveFile(fileID, userID uint, ipAddress string) error { return nil } +func (m *FileModel) UnarchiveFile(fileID, userID uint, ipAddress string) error { + tx := m.db.Begin() + + // Unarchive the file + result := tx.Model(&File{}). + Where("id = ? AND user_id = ? AND is_archived = ?", fileID, userID, true). + Update("is_archived", false) + + if result.Error != nil { + tx.Rollback() + return fmt.Errorf("failed to unarchive file: %w", result.Error) + } + + if result.RowsAffected == 0 { + tx.Rollback() + return fmt.Errorf("file not found or not archived") + } + + // Log activity + activity := &ActivityLog{ + UserID: userID, + ActivityType: "unarchive", + FileID: &fileID, + IPAddress: ipAddress, + Status: "success", + } + + if err := tx.Create(activity).Error; err != nil { + tx.Rollback() + return fmt.Errorf("failed to log activity: %w", err) + } + + if err := tx.Commit().Error; err != nil { + return fmt.Errorf("failed to complete unarchive operation: %w", err) + } + + return nil +} // Storage management func (m *FileModel) GetUserStorageInfo(userID uint) (used int64, quota int64, err error) { @@ -573,7 +607,6 @@ func (m *FileModel) RecoverFile(fileID, userID uint) error { log.Printf("Starting file recovery process - File ID: %d, User ID: %d", fileID, userID) - // Get the deleted file var file File if err := tx.Where("id = ? AND user_id = ? AND is_deleted = ?", fileID, userID, true).First(&file).Error; err != nil { tx.Rollback() @@ -584,7 +617,6 @@ func (m *FileModel) RecoverFile(fileID, userID uint) error { log.Printf("Found deleted file - ID: %d, IsSharded: %v, Size: %d bytes", file.ID, file.IsSharded, file.Size) - // Verify storage space usedStorage, quota, err := m.GetUserStorageInfo(userID) if err != nil { tx.Rollback() @@ -599,10 +631,8 @@ func (m *FileModel) RecoverFile(fileID, userID uint) error { return fmt.Errorf("insufficient storage space for recovery") } - // Verify file data if file.IsSharded { log.Printf("Verifying shards for file %d", file.ID) - // Check shards integrity fileShards, err := m.rsService.RetrieveShards(file.ID, int(file.DataShardCount+file.ParityShardCount)) if err != nil { tx.Rollback() @@ -701,11 +731,11 @@ func (f *File) ValidateIVSize() error { var expectedSize int switch f.EncryptionType { case services.ChaCha20: - expectedSize = 24 // XChaCha20-Poly1305 + expectedSize = 24 case services.Twofish: - expectedSize = 12 // Twofish-GCM + expectedSize = 12 case services.StandardEncryption: - expectedSize = 16 // AES-GCM + expectedSize = 16 default: return fmt.Errorf("unsupported encryption type: %s", f.EncryptionType) } @@ -798,7 +828,6 @@ func (m *FileModel) PermanentlyDeleteFile(fileID, userID uint, ipAddress string) return nil } -// PermanentDeletionLog represents an audit log for permanent deletions type PermanentDeletionLog struct { ID uint `gorm:"primaryKey"` UserID uint `gorm:"not null"` diff --git a/backend/models/file_share.go b/backend/models/file_share.go index 5964968..e394694 100644 --- a/backend/models/file_share.go +++ b/backend/models/file_share.go @@ -1,234 +1,188 @@ package models import ( - "crypto/rand" - "encoding/base64" - "fmt" - "log" - "time" - - "golang.org/x/crypto/bcrypt" - "gorm.io/gorm" + "crypto/rand" + "encoding/base64" + "fmt" + "time" + "golang.org/x/crypto/bcrypt" + "gorm.io/gorm" +) + +type ShareType string + +const ( + NormalShare ShareType = "normal" + RecipientShare ShareType = "recipient" ) type FileShare struct { - ID uint `json:"id" gorm:"primaryKey"` - FileID uint `json:"file_id"` - SharedBy uint `json:"shared_by"` - ShareLink string `json:"share_link" gorm:"unique"` - PasswordHash string `json:"-"` - PasswordSalt string `json:"-"` - EncryptedKeyFragment []byte `json:"-" gorm:"type:mediumblob"` - FragmentIndex int `json:"-" gorm:"not null"` - ExpiresAt *time.Time `json:"expires_at"` - MaxDownloads *int `json:"max_downloads"` - DownloadCount int `json:"download_count" gorm:"default:0"` - IsActive bool `json:"is_active" gorm:"default:true"` - CreatedAt time.Time `json:"created_at"` - File File `json:"file" gorm:"foreignKey:FileID"` + ID uint `json:"id" gorm:"primaryKey"` + FileID uint `json:"file_id"` + SharedBy uint `json:"shared_by"` + ShareLink string `json:"share_link" gorm:"unique"` + PasswordHash string `json:"-"` + PasswordSalt string `json:"-"` + EncryptedKeyFragment []byte `json:"-" gorm:"type:mediumblob"` + FragmentIndex int `json:"-" gorm:"not null"` + ExpiresAt *time.Time `json:"expires_at"` + MaxDownloads *int `json:"max_downloads"` + DownloadCount int `json:"download_count" gorm:"default:0"` + IsActive bool `json:"is_active" gorm:"default:true"` + CreatedAt time.Time `json:"created_at"` + File File `json:"file" gorm:"foreignKey:FileID"` + ShareType ShareType `json:"share_type" gorm:"type:varchar(20);default:'normal'"` + Email string `json:"email,omitempty"` } type FileShareModel struct { - db *gorm.DB + db *gorm.DB } func NewFileShareModel(db *gorm.DB) *FileShareModel { - return &FileShareModel{db: db} + return &FileShareModel{db: db} } func generateShareLink() (string, error) { - bytes := make([]byte, 32) - if _, err := rand.Read(bytes); err != nil { - return "", err - } - return base64.URLEncoding.EncodeToString(bytes), nil + bytes := make([]byte, 32) + if _, err := rand.Read(bytes); err != nil { + return "", err + } + return base64.URLEncoding.EncodeToString(bytes), nil } -func (m *FileShareModel) CreateFileShareWithStatus(share *FileShare, password string) error { - // Generate password salt - salt := make([]byte, 16) - if _, err := rand.Read(salt); err != nil { - return fmt.Errorf("failed to generate salt: %w", err) - } - share.PasswordSalt = base64.StdEncoding.EncodeToString(salt) - - // Hash password with salt - hashedPassword, err := bcrypt.GenerateFromPassword( - []byte(password+share.PasswordSalt), - bcrypt.DefaultCost, - ) - if err != nil { - return fmt.Errorf("failed to hash password: %w", err) - } - share.PasswordHash = string(hashedPassword) - - // Generate unique share link - shareLink, err := generateShareLink() - if err != nil { - return fmt.Errorf("failed to generate share link: %w", err) - } - share.ShareLink = shareLink - - // Start transaction - tx := m.db.Begin() - if tx.Error != nil { - return fmt.Errorf("failed to start transaction: %w", tx.Error) - } - - // Create share record within transaction - if err := tx.Create(share).Error; err != nil { - tx.Rollback() - return fmt.Errorf("failed to create share record: %w", err) - } - - // Update file's IsShared status - if err := tx.Model(&File{}).Where("id = ?", share.FileID).Update("is_shared", true).Error; err != nil { - tx.Rollback() - return fmt.Errorf("failed to update file status: %w", err) - } - - // Commit transaction - if err := tx.Commit().Error; err != nil { - tx.Rollback() - return fmt.Errorf("failed to commit transaction: %w", err) - } +func (m *FileShareModel) CreateFileShare(share *FileShare, password string) error { + if share.ShareType == RecipientShare && share.Email == "" { + return fmt.Errorf("email required for recipient share") + } + + salt := make([]byte, 16) + if _, err := rand.Read(salt); err != nil { + return fmt.Errorf("failed to generate salt: %w", err) + } + share.PasswordSalt = base64.StdEncoding.EncodeToString(salt) + + hashedPassword, err := bcrypt.GenerateFromPassword( + []byte(password+share.PasswordSalt), + bcrypt.DefaultCost, + ) + if err != nil { + return fmt.Errorf("failed to hash password: %w", err) + } + share.PasswordHash = string(hashedPassword) + + shareLink, err := generateShareLink() + if err != nil { + return fmt.Errorf("failed to generate share link: %w", err) + } + share.ShareLink = shareLink + + tx := m.db.Begin() + if tx.Error != nil { + return fmt.Errorf("failed to start transaction: %w", tx.Error) + } + + if err := tx.Create(share).Error; err != nil { + tx.Rollback() + return fmt.Errorf("failed to create share record: %w", err) + } + + if err := tx.Model(&File{}).Where("id = ?", share.FileID).Update("is_shared", true).Error; err != nil { + tx.Rollback() + return fmt.Errorf("failed to update file status: %w", err) + } + + return tx.Commit().Error +} - return nil +func (m *FileShareModel) ValidateShare(shareLink string, password string) (*FileShare, error) { + var share FileShare + if err := m.db.Where("share_link = ? AND is_active = ? AND share_type = ?", + shareLink, true, NormalShare).Preload("File").First(&share).Error; err != nil { + return nil, fmt.Errorf("share not found or inactive") + } + + if err := bcrypt.CompareHashAndPassword( + []byte(share.PasswordHash), + []byte(password+share.PasswordSalt), + ); err != nil { + return nil, fmt.Errorf("invalid password") + } + + return &share, nil } -func (m *FileShareModel) ValidateShareAccess(shareLink string, password string) (*FileShare, error) { +func (m *FileShareModel) ValidateRecipientShare(shareLink string, password string) (*FileShare, error) { var share FileShare - if err := m.db.Where("share_link = ? AND is_active = ?", shareLink, true). - Preload("File").First(&share).Error; err != nil { - return nil, fmt.Errorf("share not found or inactive") + if err := m.db.Where("share_link = ? AND is_active = ? AND share_type = ?", + shareLink, true, RecipientShare).Preload("File").First(&share).Error; err != nil { + return nil, fmt.Errorf("share not found or invalid") } - - // Check expiration + if share.ExpiresAt != nil && share.ExpiresAt.Before(time.Now()) { share.IsActive = false m.db.Save(&share) return nil, fmt.Errorf("share has expired") } - - // Check download limit + if share.MaxDownloads != nil && share.DownloadCount >= *share.MaxDownloads { share.IsActive = false m.db.Save(&share) return nil, fmt.Errorf("download limit exceeded") } - - // Verify password + if err := bcrypt.CompareHashAndPassword( []byte(share.PasswordHash), []byte(password+share.PasswordSalt), ); err != nil { return nil, fmt.Errorf("invalid password") } - - return &share, nil -} - -// CreateFileShare creates a basic file share with just password protection -func (m *FileShareModel) CreateFileShare(share *FileShare, password string) error { - // Generate password salt - salt := make([]byte, 16) - if _, err := rand.Read(salt); err != nil { - return fmt.Errorf("failed to generate salt: %w", err) - } - share.PasswordSalt = base64.StdEncoding.EncodeToString(salt) - - // Hash password with salt - hashedPassword, err := bcrypt.GenerateFromPassword( - []byte(password+share.PasswordSalt), - bcrypt.DefaultCost, - ) - if err != nil { - return fmt.Errorf("failed to hash password: %w", err) - } - share.PasswordHash = string(hashedPassword) - - // Generate unique share link - shareLink, err := generateShareLink() - if err != nil { - return fmt.Errorf("failed to generate share link: %w", err) - } - share.ShareLink = shareLink - - // Start transaction - tx := m.db.Begin() - if tx.Error != nil { - return fmt.Errorf("failed to start transaction: %w", tx.Error) - } - - // Create share record within transaction - if err := tx.Create(share).Error; err != nil { - tx.Rollback() - return fmt.Errorf("failed to create share record: %w", err) - } - - // Update file's IsShared status - if err := tx.Model(&File{}).Where("id = ?", share.FileID).Update("is_shared", true).Error; err != nil { - tx.Rollback() - return fmt.Errorf("failed to update file status: %w", err) - } - - return tx.Commit().Error -} - -// ValidateShare validates a share without checking expiry or download count -func (m *FileShareModel) ValidateShare(shareLink string, password string) (*FileShare, error) { - var share FileShare - if err := m.db.Where("share_link = ? AND is_active = ?", shareLink, true). - Preload("File").First(&share).Error; err != nil { - return nil, fmt.Errorf("share not found or inactive") - } - - // Verify password - if err := bcrypt.CompareHashAndPassword( - []byte(share.PasswordHash), - []byte(password+share.PasswordSalt), - ); err != nil { - return nil, fmt.Errorf("invalid password") - } - + return &share, nil + } + func (m *FileShareModel) ValidatePassword(shareLink string, password string) error { + var share FileShare + if err := m.db.Where("share_link = ?", shareLink).First(&share).Error; err != nil { + return fmt.Errorf("share not found") + } + + if err := bcrypt.CompareHashAndPassword( + []byte(share.PasswordHash), + []byte(password+share.PasswordSalt), + ); err != nil { + return fmt.Errorf("invalid password") + } + + return nil } func (m *FileShareModel) IncrementDownloadCount(shareID uint) error { - log.Printf("Starting IncrementDownloadCount for share ID %d", shareID) + tx := m.db.Begin() + if tx.Error != nil { + return fmt.Errorf("failed to start transaction: %w", tx.Error) + } + defer tx.Rollback() - // Start transaction - tx := m.db.Begin() - if tx.Error != nil { - log.Printf("Failed to start transaction: %v", tx.Error) - return fmt.Errorf("failed to start transaction: %w", tx.Error) - } - defer tx.Rollback() // rollback if not committed + result := tx.Model(&FileShare{}). + Where("id = ?", shareID). + Update("download_count", gorm.Expr("download_count + ?", 1)) - log.Printf("Started transaction for share ID %d", shareID) + if result.Error != nil { + return fmt.Errorf("failed to increment download count: %w", result.Error) + } - result := tx.Model(&FileShare{}). - Where("id = ?", shareID). - Update("download_count", gorm.Expr("download_count + ?", 1)) - - if result.Error != nil { - log.Printf("Error during update: %v", result.Error) - return fmt.Errorf("failed to increment download count: %w", result.Error) - } - - log.Printf("Update query executed, affected rows: %d", result.RowsAffected) - - if result.RowsAffected == 0 { - log.Printf("No rows affected for share ID %d", shareID) - return fmt.Errorf("no share found with ID %d", shareID) - } - - // Commit transaction - if err := tx.Commit().Error; err != nil { - log.Printf("Failed to commit transaction: %v", err) - return fmt.Errorf("failed to commit transaction: %w", err) - } + if result.RowsAffected == 0 { + return fmt.Errorf("no share found with ID %d", shareID) + } - log.Printf("Successfully committed download count increment for share ID %d", shareID) - return nil + return tx.Commit().Error } +func (m *FileShareModel) GetShareByLink(shareLink string) (*FileShare, error) { + var share FileShare + err := m.db.Where("share_link = ? AND is_active = ?", shareLink, true). + Preload("File").First(&share).Error + if err != nil { + return nil, fmt.Errorf("share not found or inactive") + } + return &share, nil +} \ No newline at end of file diff --git a/backend/models/key_rotation.go b/backend/models/key_rotation.go deleted file mode 100644 index 19b1525..0000000 --- a/backend/models/key_rotation.go +++ /dev/null @@ -1,162 +0,0 @@ -package models - -import ( - "fmt" - "time" - - "gorm.io/gorm" -) - -type RotationType string - -const ( - RotationTypeAutomatic RotationType = "automatic" - RotationTypeManual RotationType = "manual" - RotationTypeForced RotationType = "forced" -) -type KeyRotationHistory struct { - ID uint `json:"id" gorm:"primaryKey"` - UserID uint `json:"user_id"` - OldKeyVersion int `json:"old_key_version"` - NewKeyVersion int `json:"new_key_version"` - RotationType RotationType `json:"rotation_type" gorm:"type:enum('automatic','manual','forced','password_change')"` - RotatedAt time.Time `json:"rotated_at" gorm:"autoCreateTime"` -} - -type KeyRotationModel struct { - db *gorm.DB -} - -func NewKeyRotationModel(db *gorm.DB) *KeyRotationModel { - return &KeyRotationModel{db: db} -} - -// LogRotation records a key rotation event -func (m *KeyRotationModel) LogRotation(userID uint, oldVersion, newVersion int, rotationType RotationType) error { - rotation := KeyRotationHistory{ - UserID: userID, - OldKeyVersion: oldVersion, - NewKeyVersion: newVersion, - RotationType: rotationType, - } - - if err := m.db.Create(&rotation).Error; err != nil { - return fmt.Errorf("failed to log key rotation: %w", err) - } - - return nil -} - -// GetRotationHistory retrieves all rotation events for a user -func (m *KeyRotationModel) GetRotationHistory(userID uint) ([]KeyRotationHistory, error) { - var history []KeyRotationHistory - - err := m.db.Where("user_id = ?", userID). - Order("rotated_at DESC"). - Find(&history).Error - if err != nil { - return nil, fmt.Errorf("failed to retrieve rotation history: %w", err) - } - - return history, nil -} - -// GetLatestRotation gets the most recent key rotation event for a user -func (m *KeyRotationModel) GetLatestRotation(userID uint) (*KeyRotationHistory, error) { - var rotation KeyRotationHistory - - err := m.db.Where("user_id = ?", userID). - Order("rotated_at DESC"). - First(&rotation).Error - if err != nil { - if err == gorm.ErrRecordNotFound { - return nil, nil - } - return nil, fmt.Errorf("failed to retrieve latest rotation: %w", err) - } - - return &rotation, nil -} - -// CountRotationsByType counts rotations by type within a time period -func (m *KeyRotationModel) CountRotationsByType(userID uint, rotationType RotationType, since time.Time) (int64, error) { - var count int64 - - err := m.db.Model(&KeyRotationHistory{}). - Where("user_id = ? AND rotation_type = ? AND rotated_at >= ?", userID, rotationType, since). - Count(&count).Error - if err != nil { - return 0, fmt.Errorf("failed to count rotations: %w", err) - } - - return count, nil -} - -// CheckRotationNeeded determines if a key rotation is needed based on time since last rotation -func (m *KeyRotationModel) CheckRotationNeeded(userID uint, maxAge time.Duration) (bool, error) { - lastRotation, err := m.GetLatestRotation(userID) - if err != nil { - return false, err - } - - // If no rotation history exists, rotation is needed - if lastRotation == nil { - return true, nil - } - - // Calculate time since last rotation - timeSinceRotation := time.Since(lastRotation.RotatedAt) - return timeSinceRotation >= maxAge, nil -} - -// GetRotationsBetween gets all rotations between two timestamps -func (m *KeyRotationModel) GetRotationsBetween(userID uint, start, end time.Time) ([]KeyRotationHistory, error) { - var rotations []KeyRotationHistory - - err := m.db.Where("user_id = ? AND rotated_at BETWEEN ? AND ?", userID, start, end). - Order("rotated_at DESC"). - Find(&rotations).Error - if err != nil { - return nil, fmt.Errorf("failed to retrieve rotations: %w", err) - } - - return rotations, nil -} - -// GetUsersByRotationStatus gets users who need key rotation based on age threshold -func (m *KeyRotationModel) GetUsersByRotationStatus(maxAge time.Duration) ([]uint, error) { - var userIDs []uint - threshold := time.Now().Add(-maxAge) - - // Subquery to get latest rotation per user - latestRotations := m.db.Table("key_rotation_history"). - Select("user_id, MAX(rotated_at) as last_rotation"). - Group("user_id") - - // Get users whose last rotation is older than threshold or who have no rotations - err := m.db.Table("users"). - Select("users.id"). - Joins("LEFT JOIN (?) as latest_rotations ON users.id = latest_rotations.user_id", latestRotations). - Where("latest_rotations.last_rotation < ? OR latest_rotations.last_rotation IS NULL", threshold). - Pluck("users.id", &userIDs).Error - if err != nil { - return nil, fmt.Errorf("failed to get users needing rotation: %w", err) - } - - return userIDs, nil -} - -// ValidateKeyVersion checks if a given key version matches the latest version for a user -func (m *KeyRotationModel) ValidateKeyVersion(userID uint, version int) (bool, error) { - latestRotation, err := m.GetLatestRotation(userID) - if err != nil { - return false, err - } - - if latestRotation == nil { - // If no rotation history, version should be 1 - return version == 1, nil - } - - return version == latestRotation.NewKeyVersion, nil -} diff --git a/backend/models/password_history.go b/backend/models/password_history.go index 1b6856a..b715b6b 100644 --- a/backend/models/password_history.go +++ b/backend/models/password_history.go @@ -2,7 +2,7 @@ package models import ( "time" - + "fmt" "gorm.io/gorm" ) @@ -14,7 +14,6 @@ type PasswordHistory struct { User User `json:"-" gorm:"foreignKey:UserID"` } -// TableName overrides the default table name used by GORM. func (PasswordHistory) TableName() string { return "password_history" } @@ -27,7 +26,6 @@ func NewPasswordHistoryModel(db *gorm.DB) *PasswordHistoryModel { return &PasswordHistoryModel{db: db} } -// AddEntry adds a new password history entry func (m *PasswordHistoryModel) AddEntry(userID uint, passwordHash string) error { entry := &PasswordHistory{ UserID: userID, @@ -36,7 +34,6 @@ func (m *PasswordHistoryModel) AddEntry(userID uint, passwordHash string) error return m.db.Create(entry).Error } -// GetRecentPasswords retrieves the most recent password hashes for a user func (m *PasswordHistoryModel) GetRecentPasswords(userID uint, limit int) ([]string, error) { var entries []PasswordHistory if err := m.db.Where("user_id = ?", userID). @@ -52,51 +49,21 @@ func (m *PasswordHistoryModel) GetRecentPasswords(userID uint, limit int) ([]str } return hashes, nil } +func (m *PasswordHistoryModel) IsPasswordReused(userID uint, newPasswordHash string) (bool, error) { + recentPasswords, err := m.GetRecentPasswords(userID, 5) + if err != nil { + return false, fmt.Errorf("failed to get recent passwords: %w", err) + } -// CleanupOldEntries removes password history entries older than the specified duration -func (m *PasswordHistoryModel) CleanupOldEntries(userID uint, olderThan time.Duration) error { - cutoffTime := time.Now().Add(-olderThan) - return m.db.Where("user_id = ? AND changed_at < ?", userID, cutoffTime). - Delete(&PasswordHistory{}).Error -} + for _, oldHash := range recentPasswords { + if oldHash == newPasswordHash { + return true, nil + } + } -// CountUserEntries counts the number of password history entries for a user -func (m *PasswordHistoryModel) CountUserEntries(userID uint) (int64, error) { - var count int64 - err := m.db.Model(&PasswordHistory{}). - Where("user_id = ?", userID). - Count(&count).Error - return count, err + return false, nil } -func (m *PasswordHistoryModel) IsPasswordReused(userID uint, newPasswordHash string) (bool, error) { - recentPasswords, err := m.GetRecentPasswords(userID, 5) // Check last 5 passwords - if err != nil { - return false, err - } - for _, oldHash := range recentPasswords { - if oldHash == newPasswordHash { - return true, nil - } - } - return false, nil -} -func (m *PasswordHistoryModel) ScheduleCleanup(duration time.Duration) { - ticker := time.NewTicker(24 * time.Hour) // Run daily - go func() { - for range ticker.C { - // Get all users - var users []User - if err := m.db.Find(&users).Error; err != nil { - continue - } - // Clean up old entries for each user - for _, user := range users { - _ = m.CleanupOldEntries(user.ID, duration) - } - } - }() -} diff --git a/backend/models/server_master_key.go b/backend/models/server_master_key.go index 026148c..8e195c6 100644 --- a/backend/models/server_master_key.go +++ b/backend/models/server_master_key.go @@ -4,7 +4,6 @@ import ( "crypto/rand" "encoding/hex" "fmt" - "log" "safesplit/utils" "time" @@ -14,7 +13,7 @@ import ( type ServerMasterKey struct { ID uint `json:"id" gorm:"primaryKey"` KeyID string `json:"key_id" gorm:"type:varchar(64);unique;not null"` - EncryptedKey []byte `json:"-" gorm:"type:binary(64);not null"` + EncryptedKey []byte `json:"-" gorm:"type:binary(32);not null"` KeyNonce []byte `json:"-" gorm:"type:binary(16);not null"` IsActive bool `json:"is_active" gorm:"default:true"` CreatedAt time.Time `json:"created_at" gorm:"autoCreateTime"` @@ -40,51 +39,40 @@ func generateKeyID() (string, error) { // Initialize generates and stores the first server master key if none exists func (m *ServerMasterKeyModel) Initialize() error { - // Check if there's already an active key var count int64 if err := m.db.Model(&ServerMasterKey{}).Where("is_active = ?", true).Count(&count).Error; err != nil { return fmt.Errorf("failed to check existing keys: %w", err) } if count > 0 { - return nil // Server key already exists + return nil } - // Generate a new 32-byte master key masterKey := make([]byte, 32) if _, err := rand.Read(masterKey); err != nil { return fmt.Errorf("failed to generate master key: %w", err) } - // For 64-byte storage, pad with zeros - paddedKey := make([]byte, 64) - copy(paddedKey, masterKey) - keyID, err := generateKeyID() if err != nil { return fmt.Errorf("failed to generate key ID: %w", err) } - // Generate 16-byte nonce - nonce := make([]byte, 16) - if _, err := rand.Read(nonce); err != nil { + nonce, err := utils.GenerateNonce() + if err != nil { return fmt.Errorf("failed to generate nonce: %w", err) } now := time.Now() serverKey := &ServerMasterKey{ KeyID: keyID, - EncryptedKey: paddedKey, // Use padded 64-byte key - KeyNonce: nonce, // 16-byte nonce + EncryptedKey: masterKey, + KeyNonce: nonce, IsActive: true, ActivatedAt: &now, } - if err := m.db.Create(serverKey).Error; err != nil { - return fmt.Errorf("failed to store server master key: %w", err) - } - - return nil + return m.db.Create(serverKey).Error } // GetServerKey retrieves and processes the server key for encryption @@ -94,26 +82,11 @@ func (m *ServerMasterKeyModel) GetServerKey(keyID string) ([]byte, error) { return nil, fmt.Errorf("failed to get server key: %w", err) } - // Add debug logging - log.Printf("Retrieved key length: %d bytes", len(key.EncryptedKey)) - log.Printf("Raw key bytes: %v", key.EncryptedKey) - - // Handle hex-encoded string if that's what we're getting - if len(key.EncryptedKey) == 64 { - // Use first 32 bytes - log.Printf("Using first 32 bytes of 64-byte key") - return key.EncryptedKey[:32], nil - } - - // For any other length, try to decode if it's hex-encoded - if decoded, err := hex.DecodeString(string(key.EncryptedKey)); err == nil { - if len(decoded) >= 32 { - log.Printf("Decoded hex string to %d bytes, using first 32", len(decoded)) - return decoded[:32], nil - } + if len(key.EncryptedKey) != 32 { + return nil, fmt.Errorf("invalid key length: got %d, expected 32 bytes", len(key.EncryptedKey)) } - return nil, fmt.Errorf("invalid server key length in database: got %d bytes, need 64 for raw or hex-encoded key", len(key.EncryptedKey)) + return key.EncryptedKey, nil } // GetActive retrieves the current active server master key @@ -125,56 +98,6 @@ func (m *ServerMasterKeyModel) GetActive() (*ServerMasterKey, error) { return &key, nil } -// Rotate generates a new master key and retires the old one -func (m *ServerMasterKeyModel) Rotate() error { - return m.db.Transaction(func(tx *gorm.DB) error { - // Get current active key - var currentKey ServerMasterKey - if err := tx.Where("is_active = ? AND retired_at IS NULL", true).First(¤tKey).Error; err != nil { - return fmt.Errorf("failed to get current server key: %w", err) - } - - // Generate new 64-byte key - masterKey := make([]byte, 64) - if _, err := rand.Read(masterKey); err != nil { - return fmt.Errorf("failed to generate master key: %w", err) - } - - nonce, err := utils.GenerateNonce() - if err != nil { - return fmt.Errorf("failed to generate nonce: %w", err) - } - - keyID, err := generateKeyID() - if err != nil { - return fmt.Errorf("failed to generate key ID: %w", err) - } - - now := time.Now() - newKey := &ServerMasterKey{ - KeyID: keyID, - EncryptedKey: masterKey, - KeyNonce: nonce, - IsActive: true, - ActivatedAt: &now, - } - - if err := tx.Create(newKey).Error; err != nil { - return fmt.Errorf("failed to create new key: %w", err) - } - - // Retire old key - if err := tx.Model(¤tKey).Updates(map[string]interface{}{ - "is_active": false, - "retired_at": now, - }).Error; err != nil { - return fmt.Errorf("failed to retire old key: %w", err) - } - - return nil - }) -} - // GetByID retrieves a specific server master key by its ID func (m *ServerMasterKeyModel) GetByID(keyID string) (*ServerMasterKey, error) { var key ServerMasterKey diff --git a/backend/models/user.go b/backend/models/user.go index 52be796..7bb44cb 100644 --- a/backend/models/user.go +++ b/backend/models/user.go @@ -7,6 +7,7 @@ import ( "safesplit/services" "safesplit/utils" "strconv" + "strings" "time" "golang.org/x/crypto/bcrypt" @@ -157,19 +158,33 @@ func (m *UserModel) Authenticate(email, password string) (*User, error) { // Check if account is locked if user.AccountLockedUntil != nil && user.AccountLockedUntil.After(time.Now()) { - return nil, fmt.Errorf("account locked until %v", user.AccountLockedUntil) + remainingTime := user.AccountLockedUntil.Sub(time.Now()) + return nil, fmt.Errorf("account locked for %d minutes due to too many failed attempts", int(remainingTime.Minutes())) } // Verify password if err := bcrypt.CompareHashAndPassword([]byte(user.Password), []byte(password)); err != nil { - if err := m.handleFailedLogin(&user); err != nil { - return nil, err + lockErr := m.handleFailedLogin(&user) + if lockErr != nil && strings.Contains(lockErr.Error(), "account locked") { + return nil, lockErr + } + remainingAttempts := 5 - user.FailedLoginAttempts + if remainingAttempts > 0 { + return nil, fmt.Errorf("invalid credentials - %d attempts remaining", remainingAttempts) } return nil, errors.New("invalid credentials") } - // Reset failed attempts and update login time - return m.handleSuccessfulLogin(&user) + updatedUser, err := m.handleSuccessfulLogin(&user) + if err != nil { + return nil, fmt.Errorf("failed to update login status: %w", err) + } + + if updatedUser.Role == RoleSysAdmin || updatedUser.TwoFactorEnabled { + return updatedUser, nil + } + + return updatedUser, nil } // handleSuccessfulLogin updates user state after successful login @@ -180,7 +195,7 @@ func (m *UserModel) handleSuccessfulLogin(user *User) (*User, error) { user.LastLogin = &now if err := m.db.Save(user).Error; err != nil { - return nil, err + return nil, fmt.Errorf("failed to update login state: %w", err) } return user, nil @@ -190,12 +205,22 @@ func (m *UserModel) handleSuccessfulLogin(user *User) (*User, error) { func (m *UserModel) handleFailedLogin(user *User) error { user.FailedLoginAttempts++ + var lockTime *time.Time if user.FailedLoginAttempts >= 5 { - lockTime := time.Now().Add(30 * time.Minute) - user.AccountLockedUntil = &lockTime + t := time.Now().Add(30 * time.Minute) + lockTime = &t + user.AccountLockedUntil = lockTime } - return m.db.Save(user).Error + if err := m.db.Save(user).Error; err != nil { + return fmt.Errorf("failed to update login attempts: %w", err) + } + + if lockTime != nil { + return fmt.Errorf("account locked until %v due to too many failed attempts", lockTime) + } + + return nil } // UpdateMasterKey updates the user's master key material @@ -231,40 +256,6 @@ func (u *User) UpdateMasterKey(db *gorm.DB, newEncryptedKey []byte) error { return nil } -// RotateMasterKey performs a key rotation operation -func (m *UserModel) RotateMasterKey(userID uint, newEncryptedKey []byte, rotationType RotationType, keyRotationModel *KeyRotationModel) error { - return m.db.Transaction(func(tx *gorm.DB) error { - // Get user and verify existence - var user User - if err := tx.First(&user, userID).Error; err != nil { - return fmt.Errorf("user not found: %w", err) - } - - // Verify rotation type is valid - switch rotationType { - case RotationTypeAutomatic, RotationTypeManual, RotationTypeForced: - // Valid rotation type - default: - return fmt.Errorf("invalid rotation type: %s", rotationType) - } - - // Store old version for logging - oldVersion := user.MasterKeyVersion - - // Update the master key - if err := user.UpdateMasterKey(tx, newEncryptedKey); err != nil { - return fmt.Errorf("failed to update master key: %w", err) - } - - // Log the rotation using the KeyRotationModel - if err := keyRotationModel.LogRotation(userID, oldVersion, user.MasterKeyVersion, rotationType); err != nil { - return fmt.Errorf("failed to log rotation: %w", err) - } - - return nil - }) -} - func (m *UserModel) AuthenticateSuperAdmin(email, password string) (*User, error) { var user User if err := m.db.Where("email = ? AND is_active = ? AND role = ?", @@ -272,20 +263,29 @@ func (m *UserModel) AuthenticateSuperAdmin(email, password string) (*User, error return nil, errors.New("invalid credentials") } - // Check if account is locked if user.AccountLockedUntil != nil && user.AccountLockedUntil.After(time.Now()) { - return nil, fmt.Errorf("account locked until %v", user.AccountLockedUntil) + remainingTime := user.AccountLockedUntil.Sub(time.Now()) + return nil, fmt.Errorf("account locked for %d minutes due to too many failed attempts", int(remainingTime.Minutes())) } - // Verify password if err := bcrypt.CompareHashAndPassword([]byte(user.Password), []byte(password)); err != nil { - if err := m.handleFailedLogin(&user); err != nil { - return nil, err + lockErr := m.handleFailedLogin(&user) + if lockErr != nil && strings.Contains(lockErr.Error(), "account locked") { + return nil, lockErr + } + remainingAttempts := 5 - user.FailedLoginAttempts + if remainingAttempts > 0 { + return nil, fmt.Errorf("invalid credentials - %d attempts remaining", remainingAttempts) } return nil, errors.New("invalid credentials") } - return m.handleSuccessfulLogin(&user) + updatedUser, err := m.handleSuccessfulLogin(&user) + if err != nil { + return nil, fmt.Errorf("failed to update login status: %w", err) + } + + return updatedUser, nil } // FindByEmail retrieves a user by their email @@ -346,76 +346,36 @@ func (u *User) UpdateSubscription(db *gorm.DB, status string) error { return db.Save(u).Error } -// ResetPassword updates the user's password and optionally rotates master key -func (m *UserModel) ResetPassword(userID uint, currentPassword, newPassword string, newEncryptedMasterKey []byte, passwordHistoryModel *PasswordHistoryModel) error { - return m.db.Transaction(func(tx *gorm.DB) error { - var user User - if err := tx.First(&user, userID).Error; err != nil { - return errors.New("user not found") - } - - // Verify current password - if err := bcrypt.CompareHashAndPassword([]byte(user.Password), []byte(currentPassword)); err != nil { - return errors.New("current password is incorrect") - } +// Create Sys admin account +func (m *UserModel) CreateSysAdmin(creator *User, newAdmin *User) (*User, error) { + if !creator.IsSuperAdmin() { + return nil, errors.New("unauthorized: only super admins can create system administrators") + } - // Store the old password in history - if err := passwordHistoryModel.AddEntry(user.ID, user.Password); err != nil { - return err - } + err := m.db.Transaction(func(tx *gorm.DB) error { + newAdmin.Role = RoleSysAdmin - // Update password and related fields - now := time.Now() - updates := map[string]interface{}{ - "password": string(newPassword), - "last_password_change": now, - "force_password_change": false, + if err := tx.Create(newAdmin).Error; err != nil { + return fmt.Errorf("failed to create system administrator: %v", err) } - // Update master key if provided - if newEncryptedMasterKey != nil && len(newEncryptedMasterKey) > 0 { - nonce, err := utils.GenerateNonce() - if err != nil { - return err - } - - updates["encrypted_master_key"] = newEncryptedMasterKey - updates["master_key_nonce"] = nonce - updates["master_key_version"] = user.MasterKeyVersion + 1 - updates["key_last_rotated"] = now - - // Log key rotation - rotation := KeyRotationHistory{ - UserID: userID, - OldKeyVersion: user.MasterKeyVersion, - NewKeyVersion: user.MasterKeyVersion + 1, - RotationType: RotationTypeAutomatic, - } - - if err := tx.Create(&rotation).Error; err != nil { - return fmt.Errorf("failed to log key rotation: %w", err) - } + if err := tx.Model(newAdmin).Update("two_factor_enabled", true).Error; err != nil { + return fmt.Errorf("failed to enable 2FA: %v", err) } - return tx.Model(&user).Updates(updates).Error + return nil }) -} -// Create Sys admin account -func (m *UserModel) CreateSysAdmin(creator *User, newAdmin *User) (*User, error) { - if !creator.IsSuperAdmin() { - return nil, errors.New("unauthorized: only super admins can create system administrators") + if err != nil { + return nil, err } - // Ensure the new user is created as a sys_admin - newAdmin.Role = RoleSysAdmin - - // Create the new admin user - if err := m.db.Create(newAdmin).Error; err != nil { - return nil, fmt.Errorf("failed to create system administrator: %v", err) + var createdAdmin User + if err := m.db.First(&createdAdmin, newAdmin.ID).Error; err != nil { + return nil, fmt.Errorf("failed to load created admin: %v", err) } - return newAdmin, nil + return &createdAdmin, nil } // View Sys admin account @@ -758,6 +718,31 @@ func (m *UserModel) ResetPasswordWithFragments( return fmt.Errorf("current password is incorrect") } + // Hash new password FIRST to check for reuse + hashedNewPassword, err := bcrypt.GenerateFromPassword([]byte(newPassword), bcrypt.DefaultCost) + if err != nil { + return fmt.Errorf("failed to hash new password: %w", err) + } + + // Check if the new password matches any of the recent passwords + recentPasswords, err := passwordHistoryModel.GetRecentPasswords(userID, 5) + if err != nil { + return fmt.Errorf("failed to check password history: %w", err) + } + + // Compare new password hash with recent password hashes + for _, oldHash := range recentPasswords { + if err := bcrypt.CompareHashAndPassword([]byte(oldHash), []byte(newPassword)); err == nil { + // If CompareHashAndPassword returns nil, the password matches + return errors.New("Cannot reuse any of your last 5 passwords") + } + } + + // Store current password in history BEFORE updating to new one + if err := passwordHistoryModel.AddEntry(user.ID, user.Password); err != nil { + return fmt.Errorf("failed to store password history: %w", err) + } + // Get original encrypted master key originalEncryptedKey := user.EncryptedMasterKey log.Printf("Current encrypted master key: %x", originalEncryptedKey) @@ -774,12 +759,7 @@ func (m *UserModel) ResetPasswordWithFragments( return fmt.Errorf("failed to decrypt master key: %w", err) } - // Hash new password and derive new KEK - hashedNewPassword, err := bcrypt.GenerateFromPassword([]byte(newPassword), bcrypt.DefaultCost) - if err != nil { - return fmt.Errorf("failed to hash new password: %w", err) - } - + // Derive new KEK using already hashed password newKEK, err := services.DeriveKeyEncryptionKey(string(hashedNewPassword), user.MasterKeySalt) if err != nil { return fmt.Errorf("failed to derive new KEK: %w", err) @@ -799,6 +779,11 @@ func (m *UserModel) ResetPasswordWithFragments( log.Printf("New encrypted master key: %x", newEncryptedMasterKey) + // Store old password in history BEFORE updating + if err := passwordHistoryModel.AddEntry(user.ID, user.Password); err != nil { + return fmt.Errorf("failed to store password history: %w", err) + } + // Process fragments files, err := fileModel.ListAllUserFiles(userID) if err != nil { @@ -811,49 +796,44 @@ func (m *UserModel) ResetPasswordWithFragments( for _, file := range files { fragments, err := keyFragmentModel.GetUserFragmentsForFile(file.ID) if err != nil { + if err.Error() == "record not found" { + log.Printf("Warning: skipping missing fragments for file %d", file.ID) + continue + } return fmt.Errorf("failed to get key fragments for file %d: %w", file.ID, err) } for _, fragment := range fragments { log.Printf("Processing fragment %d for file %d", fragment.FragmentIndex, file.ID) - log.Printf("Fragment data before decryption: %x", fragment.Data) - log.Printf("Fragment nonce: %x", fragment.EncryptionNonce) - - if fragment.MasterKeyVersion != nil { - log.Printf("Fragment key version: %d", *fragment.MasterKeyVersion) - } // Decrypt fragment with current decrypted master key decryptedFragment, err := services.DecryptMasterKey( fragment.Data, - userMasterKey, + userMasterKey, fragment.EncryptionNonce, ) if err != nil { - return fmt.Errorf("failed to decrypt fragment %d for file %d: %w", + log.Printf("Warning: skipping unreadable fragment %d for file %d: %v", fragment.FragmentIndex, file.ID, err) + continue } - log.Printf("Successfully decrypted fragment: %x", decryptedFragment) - // Generate new nonce for fragment newFragmentNonce, err := utils.GenerateNonce() if err != nil { return fmt.Errorf("failed to generate nonce for fragment: %w", err) } - // Re-encrypt with same decrypted master key + // Re-encrypt with same decrypted master key newEncryptedFragment, err := services.EncryptMasterKey( decryptedFragment, - userMasterKey, + userMasterKey, newFragmentNonce, ) if err != nil { return fmt.Errorf("failed to re-encrypt fragment: %w", err) } - log.Printf("Re-encrypted fragment result: %x", newEncryptedFragment) - // Store re-encrypted fragment if err := keyFragmentModel.storage.StoreFragment( fragment.NodeIndex, @@ -871,18 +851,9 @@ func (m *UserModel) ResetPasswordWithFragments( }).Error; err != nil { return fmt.Errorf("failed to update fragment metadata: %w", err) } - - log.Printf("Successfully updated fragment %d for file %d to version %d", - fragment.FragmentIndex, file.ID, newVersion) } } - // Store password history - if err := passwordHistoryModel.AddEntry(user.ID, user.Password); err != nil { - return fmt.Errorf("failed to store password history: %w", err) - } - - // Update user record now := time.Now() updates := map[string]interface{}{ "password": string(hashedNewPassword), @@ -930,7 +901,7 @@ func (m *UserModel) updateKeyFragments( // Decrypt fragment using old master key decryptedFragment, err := services.DecryptMasterKey( fragment.Data, - oldMasterKey, + oldMasterKey, fragment.EncryptionNonce, ) if err != nil { @@ -949,7 +920,7 @@ func (m *UserModel) updateKeyFragments( // Re-encrypt fragment with new decrypted master key newEncryptedFragment, err := services.EncryptMasterKey( decryptedFragment, - decryptedMasterKey, + decryptedMasterKey, newNonce, ) if err != nil { diff --git a/backend/routes/routes.go b/backend/routes/routes.go index 13b7c62..af13300 100644 --- a/backend/routes/routes.go +++ b/backend/routes/routes.go @@ -34,7 +34,9 @@ type EndUserHandlers struct { DeleteFileController *EndUser.DeleteFileController MassDeleteFileController *EndUser.MassDeleteFileController ArchiveFileController *EndUser.ArchiveFileController + UnarchiveFileController *EndUser.UnarchiveFileController MassArchiveController *EndUser.MassArchiveFileController + MassUnarchiveController *EndUser.MassUnarchiveFileController ShareFileController *EndUser.ShareFileController CreateFolderController *EndUser.CreateFolderController ViewFolderController *EndUser.ViewFolderController @@ -47,9 +49,9 @@ type EndUserHandlers struct { FeedbackController *EndUser.FeedbackController } type PremiumUserHandlers struct { - FragmentController *PremiumUser.FragmentController FileRecoveryController *PremiumUser.FileRecoveryController AdvancedShareFileController *PremiumUser.ShareFileController + UpdateBillingController *PremiumUser.UpdateBillingController } type SuperAdminHandlers struct { @@ -67,8 +69,9 @@ type SysAdminHandlers struct { ViewDeletedUserAccountController *SysAdmin.ViewDeletedUserAccountController ViewUserStorageController *SysAdmin.ViewUserStorageController ViewUserAccountDetailsController *SysAdmin.ViewUserAccountDetailsController - ViewFeedbacksController *SysAdmin.ViewFeedbacksController - ViewReportsController *SysAdmin.ViewReportsController + ViewFeedbacksController *SysAdmin.ViewFeedbacksController + ViewReportsController *SysAdmin.ViewReportsController + ViewBillingRecordsController *SysAdmin.ViewBillingRecordsController } func NewRouteHandlers( @@ -81,7 +84,6 @@ func NewRouteHandlers( folderModel *models.FolderModel, fileShareModel *models.FileShareModel, keyFragmentModel *models.KeyFragmentModel, - keyRotationModel *models.KeyRotationModel, serverMasterKeyModel *models.ServerMasterKeyModel, feedbackModel *models.FeedbackModel, encryptionService *services.EncryptionService, @@ -89,13 +91,14 @@ func NewRouteHandlers( compressionService *services.CompressionService, rsService *services.ReedSolomonService, twoFactorService *services.TwoFactorAuthService, + emailService *services.SMTPEmailService, ) *RouteHandlers { superAdminLoginController := SuperAdmin.NewLoginController(userModel) return &RouteHandlers{ - LoginController: controllers.NewLoginController(userModel, billingModel), + LoginController: controllers.NewLoginController(userModel, billingModel, activityLogModel), SuperAdminLoginController: superAdminLoginController, CreateAccountController: controllers.NewCreateAccountController(userModel, passwordHistoryModel), - TwoFactorController: EndUser.NewTwoFactorController(userModel), + TwoFactorController: EndUser.NewTwoFactorController(userModel, twoFactorService), SuperAdminHandlers: &SuperAdminHandlers{ LoginController: superAdminLoginController, CreateSysAdminController: SuperAdmin.NewCreateSysAdminController(userModel), @@ -110,8 +113,9 @@ func NewRouteHandlers( ViewDeletedUserAccountController: SysAdmin.NewViewDeletedUserAccountController(userModel), ViewUserStorageController: SysAdmin.NewViewUserStorageController(userModel), ViewUserAccountDetailsController: SysAdmin.NewViewUserAccountDetailsController(userModel, billingModel), - ViewFeedbacksController: SysAdmin.NewViewFeedbacksController(feedbackModel), - ViewReportsController: SysAdmin.NewViewReportsController(feedbackModel, userModel), + ViewFeedbacksController: SysAdmin.NewViewFeedbacksController(feedbackModel), + ViewReportsController: SysAdmin.NewViewReportsController(feedbackModel, userModel), + ViewBillingRecordsController: SysAdmin.NewViewBillingRecordsController(billingModel), }, EndUserHandlers: &EndUserHandlers{ UploadFileController: EndUser.NewFileController(fileModel, userModel, activityLogModel, encryptionService, shamirService, keyFragmentModel, compressionService, folderModel, rsService, serverMasterKeyModel), @@ -122,22 +126,24 @@ func NewRouteHandlers( DeleteFileController: EndUser.NewDeleteFileController(fileModel), MassDeleteFileController: EndUser.NewMassDeleteFileController(fileModel), ArchiveFileController: EndUser.NewArchiveFileController(fileModel), + UnarchiveFileController: EndUser.NewUnarchiveFileController(fileModel), MassArchiveController: EndUser.NewMassArchiveFileController(fileModel), - ShareFileController: EndUser.NewShareFileController(fileModel, fileShareModel, keyFragmentModel, encryptionService, activityLogModel, rsService, userModel, serverMasterKeyModel), + MassUnarchiveController: EndUser.NewMassUnarchiveFileController(fileModel), + ShareFileController: EndUser.NewShareFileController(fileModel, fileShareModel, keyFragmentModel, encryptionService, activityLogModel, rsService, userModel, serverMasterKeyModel, twoFactorService, emailService, compressionService), CreateFolderController: EndUser.NewCreateFolderController(folderModel, activityLogModel), ViewFolderController: EndUser.NewViewFolderController(folderModel, fileModel), DeleteFolderController: EndUser.NewDeleteFolderController(folderModel, activityLogModel), - PasswordResetController: EndUser.NewPasswordResetController(userModel, passwordHistoryModel, keyRotationModel, keyFragmentModel, fileModel), + PasswordResetController: EndUser.NewPasswordResetController(userModel, passwordHistoryModel, keyFragmentModel, fileModel), ViewStorageController: EndUser.NewViewStorageController(fileModel, userModel), PaymentController: EndUser.NewPaymentController(billingModel), SubscriptionController: EndUser.NewSubscriptionController(billingModel), - ReportController: EndUser.NewReportController(feedbackModel, fileModel), - FeedbackController: EndUser.NewFeedbackController(feedbackModel), + ReportController: EndUser.NewReportController(feedbackModel, fileModel), + FeedbackController: EndUser.NewFeedbackController(feedbackModel), }, PremiumUserHandlers: &PremiumUserHandlers{ - FragmentController: PremiumUser.NewFragmentController(keyFragmentModel, fileModel), FileRecoveryController: PremiumUser.NewFileRecoveryController(fileModel), - AdvancedShareFileController: PremiumUser.NewShareFileController(fileModel, fileShareModel, keyFragmentModel, encryptionService, activityLogModel, rsService, userModel, serverMasterKeyModel), + AdvancedShareFileController: PremiumUser.NewShareFileController(fileModel, fileShareModel, keyFragmentModel, encryptionService, activityLogModel, rsService, userModel, serverMasterKeyModel, twoFactorService, emailService, compressionService), + UpdateBillingController: PremiumUser.NewUpdateBillingController(billingModel), }, } } @@ -157,8 +163,17 @@ func setupPublicRoutes(api *gin.RouterGroup, handlers *RouteHandlers) { api.POST("/login", handlers.LoginController.Login) api.POST("/super-login", handlers.SuperAdminLoginController.Login) api.POST("/register", handlers.CreateAccountController.CreateAccount) + + // Public share routes + api.GET("/files/share/:shareLink", handlers.EndUserHandlers.ShareFileController.AccessShare) api.POST("/files/share/:shareLink", handlers.EndUserHandlers.ShareFileController.AccessShare) + api.POST("/files/share/:shareLink/verify", handlers.EndUserHandlers.ShareFileController.Verify2FAAndDownload) + + // Premium share routes + api.GET("/premium/shares/:shareLink", handlers.PremiumUserHandlers.AdvancedShareFileController.AccessShare) api.POST("/premium/shares/:shareLink", handlers.PremiumUserHandlers.AdvancedShareFileController.AccessShare) + api.POST("/premium/shares/:shareLink/verify", handlers.PremiumUserHandlers.AdvancedShareFileController.Verify2FAAndDownload) + api.GET("/health", func(c *gin.Context) { c.JSON(200, gin.H{"status": "ok"}) }) @@ -170,9 +185,11 @@ func setupProtectedRoutes(protected *gin.RouterGroup, handlers *RouteHandlers) { // 2FA routes twoFactor := protected.Group("/2fa") { - twoFactor.POST("/enable", handlers.TwoFactorController.EnableEmailTwoFactor) - twoFactor.POST("/disable", handlers.TwoFactorController.DisableEmailTwoFactor) twoFactor.GET("/status", handlers.TwoFactorController.GetTwoFactorStatus) + twoFactor.POST("/enable/initiate", handlers.TwoFactorController.InitiateEnable2FA) + twoFactor.POST("/enable/verify", handlers.TwoFactorController.VerifyAndEnable2FA) + twoFactor.POST("/disable/initiate", handlers.TwoFactorController.InitiateDisable2FA) + twoFactor.POST("/disable/verify", handlers.TwoFactorController.VerifyAndDisable2FA) } // End User routes should be first as they're most commonly accessed @@ -195,7 +212,6 @@ func setupProtectedRoutes(protected *gin.RouterGroup, handlers *RouteHandlers) { func setupEndUserRoutes(protected *gin.RouterGroup, handlers *EndUserHandlers) { protected.PUT("/reset-password", handlers.PasswordResetController.ResetPassword) - // Existing files routes files := protected.Group("/files") { files.GET("", handlers.ViewFilesController.ListUserFiles) @@ -208,9 +224,10 @@ func setupEndUserRoutes(protected *gin.RouterGroup, handlers *EndUserHandlers) { files.DELETE("/:id", handlers.DeleteFileController.Delete) files.POST("/mass-delete", handlers.MassDeleteFileController.Delete) files.PUT("/:id/archive", handlers.ArchiveFileController.Archive) + files.PUT("/:id/unarchive", handlers.UnarchiveFileController.Unarchive) files.POST("/mass-archive", handlers.MassArchiveController.Archive) + files.POST("/mass-unarchive", handlers.MassUnarchiveController.Unarchive) files.POST("/:id/share", handlers.ShareFileController.CreateShare) - files.GET("/share/:shareLink", handlers.ShareFileController.AccessShare) } folders := protected.Group("/folders") @@ -231,26 +248,22 @@ func setupEndUserRoutes(protected *gin.RouterGroup, handlers *EndUserHandlers) { payment.POST("/cancel", handlers.SubscriptionController.CancelSubscription) } feedback := protected.Group("/feedback") - { - feedback.POST("", handlers.FeedbackController.SubmitFeedback) - feedback.GET("", handlers.FeedbackController.GetUserFeedback) - feedback.GET("/categories", handlers.FeedbackController.GetFeedbackCategories) - } + { + feedback.POST("", handlers.FeedbackController.SubmitFeedback) + feedback.GET("", handlers.FeedbackController.GetUserFeedback) + feedback.GET("/categories", handlers.FeedbackController.GetFeedbackCategories) + } - reports := protected.Group("/reports") - { - reports.POST("/file/:id", handlers.ReportController.ReportFile) - reports.POST("/share/:shareLink", handlers.ReportController.ReportShare) - reports.GET("", handlers.ReportController.GetUserReports) - } + reports := protected.Group("/reports") + { + reports.POST("/file/:id", handlers.ReportController.ReportFile) + reports.POST("/share/:shareLink", handlers.ReportController.ReportShare) + reports.GET("", handlers.ReportController.GetUserReports) + } } func setupPremiumUserRoutes(premium *gin.RouterGroup, handlers *PremiumUserHandlers) { - // Fragment management routes - fragments := premium.Group("/fragments") - { - fragments.GET("/files/:fileId", handlers.FragmentController.GetUserFragments) - } + recovery := premium.Group("/recovery") { recovery.GET("/files", handlers.FileRecoveryController.ListRecoverableFiles) @@ -260,6 +273,11 @@ func setupPremiumUserRoutes(premium *gin.RouterGroup, handlers *PremiumUserHandl { shares.POST("/files/:id", handlers.AdvancedShareFileController.CreateShare) } + billing := premium.Group("/billing") + { + billing.GET("/details", handlers.UpdateBillingController.GetBillingDetails) + billing.PUT("/details", handlers.UpdateBillingController.UpdateBillingDetails) + } } func setupSuperAdminRoutes(superAdmin *gin.RouterGroup, handlers *SuperAdminHandlers) { @@ -285,18 +303,24 @@ func setupSysAdminRoutes(sysAdmin *gin.RouterGroup, handlers *SysAdminHandlers) sysAdmin.GET("/storage/stats", handlers.ViewUserStorageController.GetStorageStats) feedback := sysAdmin.Group("/feedback") - { - feedback.GET("", handlers.ViewFeedbacksController.GetAllFeedbacks) - feedback.GET("/:id", handlers.ViewFeedbacksController.GetFeedback) - feedback.PUT("/:id/status", handlers.ViewFeedbacksController.UpdateFeedbackStatus) - feedback.GET("/stats", handlers.ViewFeedbacksController.GetFeedbackStats) - } + { + feedback.GET("", handlers.ViewFeedbacksController.GetAllFeedbacks) + feedback.GET("/:id", handlers.ViewFeedbacksController.GetFeedback) + feedback.PUT("/:id/status", handlers.ViewFeedbacksController.UpdateFeedbackStatus) + feedback.GET("/stats", handlers.ViewFeedbacksController.GetFeedbackStats) + } - reports := sysAdmin.Group("/reports") - { - reports.GET("", handlers.ViewReportsController.GetAllReports) - reports.GET("/:id", handlers.ViewReportsController.GetReportDetails) - reports.PUT("/:id/status", handlers.ViewReportsController.UpdateReportStatus) - reports.GET("/stats", handlers.ViewReportsController.GetReportStats) - } + reports := sysAdmin.Group("/reports") + { + reports.GET("", handlers.ViewReportsController.GetAllReports) + reports.GET("/:id", handlers.ViewReportsController.GetReportDetails) + reports.PUT("/:id/status", handlers.ViewReportsController.UpdateReportStatus) + reports.GET("/stats", handlers.ViewReportsController.GetReportStats) + } + billing := sysAdmin.Group("/billing") + { + billing.GET("/records", handlers.ViewBillingRecordsController.GetAllBillingRecords) + billing.GET("/stats", handlers.ViewBillingRecordsController.GetBillingStats) + billing.GET("/expiring", handlers.ViewBillingRecordsController.GetExpiringSubscriptions) + } } diff --git a/backend/services/email.go b/backend/services/email.go index b9a05a2..2131154 100644 --- a/backend/services/email.go +++ b/backend/services/email.go @@ -57,32 +57,33 @@ func validateConfig(config SMTPConfig) error { } func (s *SMTPEmailService) connect() error { - s.mu.Lock() - defer s.mu.Unlock() - - // Establish a plaintext connection first - conn, err := smtp.Dial(fmt.Sprintf("%s:%d", s.config.Host, s.config.Port)) - if err != nil { - return fmt.Errorf("failed to establish SMTP connection: %w", err) - } - - // Send the STARTTLS command to upgrade the connection to TLS - if err := conn.StartTLS(&tls.Config{ - ServerName: s.config.Host, - MinVersion: tls.VersionTLS12, - }); err != nil { - conn.Close() - return fmt.Errorf("failed to upgrade to TLS: %w", err) - } - - // Authenticate with the SMTP server - if err := conn.Auth(s.auth); err != nil { - conn.Close() - return fmt.Errorf("authentication failed: %w", err) - } - - s.client = conn - return nil + s.mu.Lock() + defer s.mu.Unlock() + + tlsConfig := &tls.Config{ + ServerName: s.config.Host, + MinVersion: tls.VersionTLS12, + } + + conn, err := tls.Dial("tcp", fmt.Sprintf("%s:%d", s.config.Host, s.config.Port), tlsConfig) + if err != nil { + return fmt.Errorf("failed to establish TLS connection: %w", err) + } + + client, err := smtp.NewClient(conn, s.config.Host) + if err != nil { + conn.Close() + return fmt.Errorf("failed to create SMTP client: %w", err) + } + + // Authenticate + if err := client.Auth(s.auth); err != nil { + client.Close() + return fmt.Errorf("authentication failed: %w", err) + } + + s.client = client + return nil } func (s *SMTPEmailService) SendEmail(to, subject, body string) error { if err := s.connect(); err != nil { diff --git a/backend/services/two_factor_auth.go b/backend/services/two_factor_auth.go index 7b9065c..1ba1ccd 100644 --- a/backend/services/two_factor_auth.go +++ b/backend/services/two_factor_auth.go @@ -109,9 +109,9 @@ func (s *TwoFactorAuthService) SendTwoFactorToken(userID uint, email string) err body := fmt.Sprintf(`Hello, Your two-factor authentication code is: -**%s** +%s -This code will expire in **10 minutes**. Please use it to complete your login process. +This code will expire in 10 minutes. Please use it to complete your login process. If you didn't request this code, please ignore this email or contact our support team immediately. @@ -120,6 +120,40 @@ Safesplit team`, token) return s.emailSender.SendEmail(email, subject, body) } +func (s *TwoFactorAuthService) SendShareVerificationToken(shareID uint, email, fileName string) error { + if !s.rateLimiter.Allow(shareID) { + return errors.New("rate limit exceeded") + } + + token, err := generateToken() + if err != nil { + return fmt.Errorf("failed to generate token: %w", err) + } + + s.mu.Lock() + s.tokens[shareID] = &TwoFactorToken{ + Token: token, + ExpiresAt: time.Now().Add(tokenExpiry), + } + s.attempts[shareID] = 0 + s.mu.Unlock() + + subject := "Verify Your Access to Shared File" + body := fmt.Sprintf(`Hello, + +A file "%s" has been shared with you. To access this file, please use the following verification code: + +Verification Code: %s + +This code will expire in 10 minutes. Please use it along with your password to access the shared file. + +If you didn't expect to receive this file share, please ignore this email. + +Best regards, +SafeSplit Team`, fileName, token) + + return s.emailSender.SendEmail(email, subject, body) +} func (s *TwoFactorAuthService) VerifyToken(userID uint, token string) error { s.mu.Lock() diff --git a/create_test_users.sql b/create_test_users.sql index dcbb994..6559048 100644 --- a/create_test_users.sql +++ b/create_test_users.sql @@ -5,7 +5,7 @@ CREATE PROCEDURE create_test_data() BEGIN -- Clear existing data SET FOREIGN_KEY_CHECKS = 0; - TRUNCATE TABLE feedback; + TRUNCATE TABLE feedbacks; TRUNCATE TABLE user_subscriptions; TRUNCATE TABLE subscription_plans; TRUNCATE TABLE share_access_logs; @@ -14,7 +14,7 @@ BEGIN TRUNCATE TABLE key_fragments; TRUNCATE TABLE files; TRUNCATE TABLE folders; - TRUNCATE TABLE key_rotation_history; + TRUNCATE TABLE key_rotation_histories; TRUNCATE TABLE password_history; TRUNCATE TABLE server_master_keys; TRUNCATE TABLE users; @@ -71,7 +71,7 @@ BEGIN ), ( 'sys_admin', - 'sys_admin@example.com', + 'sys_admin@safesplit.xyz', '$2a$10$b.WsKp9GR.8pcdQjxMggGeCtTL7nvuc1oW2LfZu0FrM5SLv3dhkge', UNHEX('6293E61742A9A26D16ABC91564FE26157923B855F58535AD73E9720C60F94C22'), -- salt UNHEX('6293E61742A9A26D16ABC91564FE2615'), -- nonce @@ -85,7 +85,7 @@ BEGIN ), ( 'super_admin', - 'super_admin@example.com', + 'super_admin@safesplit.xyz', '$2a$10$b.WsKp9GR.8pcdQjxMggGeCtTL7nvuc1oW2LfZu0FrM5SLv3dhkge', UNHEX('B5F83100C6B4F1FF3865A9FB3A32CBB1EF4770A734026D0AE0451737109750C7'), -- salt UNHEX('B5F83100C6B4F1FF3865A9FB3A32CBB1'), -- nonce diff --git a/database-setup.sql b/database-setup.sql index a4697a5..97e4423 100644 --- a/database-setup.sql +++ b/database-setup.sql @@ -38,7 +38,7 @@ CREATE TABLE users ( CREATE TABLE server_master_keys ( id INT AUTO_INCREMENT PRIMARY KEY, key_id VARCHAR(64) NOT NULL UNIQUE, -- Unique identifier for the key - encrypted_key BINARY(64) NOT NULL, -- Encrypted server master key + encrypted_key BINARY(32) NOT NULL, -- Encrypted server master key key_nonce BINARY(16) NOT NULL, -- Nonce for key encryption is_active BOOLEAN DEFAULT TRUE, -- Whether this is the current active key created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, @@ -48,18 +48,6 @@ CREATE TABLE server_master_keys ( INDEX idx_key_id (key_id) ); --- Key rotation history table --- Purpose: Tracks key rotation events -CREATE TABLE key_rotation_histories ( - id INT AUTO_INCREMENT PRIMARY KEY, - user_id INT NOT NULL, - old_key_version INT NOT NULL, - new_key_version INT NOT NULL, - rotation_type ENUM('automatic', 'manual', 'forced') NOT NULL, - rotated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, - FOREIGN KEY (user_id) REFERENCES users(id) ON DELETE CASCADE, - INDEX idx_user_rotation (user_id, rotated_at) -); -- Password history table CREATE TABLE password_history ( @@ -170,27 +158,19 @@ CREATE TABLE file_shares ( download_count INT DEFAULT 0, -- Current downloads is_active BOOLEAN DEFAULT TRUE, -- Share status created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + share_type VARCHAR(20) NOT NULL DEFAULT 'normal', -- Share type (normal/recipient) + email VARCHAR(255) NULL, -- Recipient email for recipient shares FOREIGN KEY (file_id) REFERENCES files(id) ON DELETE CASCADE, FOREIGN KEY (shared_by) REFERENCES users(id) ON DELETE CASCADE ); --- Share access logs table -CREATE TABLE share_access_logs ( - id INT AUTO_INCREMENT PRIMARY KEY, - share_id INT NOT NULL, -- Associated share - ip_address VARCHAR(45) NOT NULL, -- Access IP - access_timestamp TIMESTAMP DEFAULT CURRENT_TIMESTAMP, - status ENUM('success', 'failed') NOT NULL, -- Access result - failure_reason VARCHAR(255), -- Failure reason - FOREIGN KEY (share_id) REFERENCES file_shares(id) ON DELETE CASCADE -); -- Activity logs table CREATE TABLE activity_logs ( id INT AUTO_INCREMENT PRIMARY KEY, user_id INT NOT NULL, -- User performing action activity_type ENUM('upload', 'download', 'delete', 'share', 'login', - 'logout', 'archive', 'restore', 'encrypt', 'decrypt') NOT NULL, + 'logout', 'archive', 'restore', 'encrypt', 'decrypt','unarchive') NOT NULL, file_id INT, -- Associated file folder_id INT, -- Associated folder ip_address VARCHAR(45), -- User's IP @@ -214,19 +194,7 @@ CREATE TABLE subscription_plans ( updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP ); --- User subscriptions table -CREATE TABLE user_subscriptions ( - id INT AUTO_INCREMENT PRIMARY KEY, - user_id INT NOT NULL, -- Subscribed user - plan_id INT NOT NULL, -- Selected plan - start_date TIMESTAMP NOT NULL, -- Subscription start - end_date TIMESTAMP NOT NULL, -- Subscription end - status ENUM('active', 'cancelled', 'expired') NOT NULL, - created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, - updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP, - FOREIGN KEY (user_id) REFERENCES users(id) ON DELETE CASCADE, - FOREIGN KEY (plan_id) REFERENCES subscription_plans(id) -); + -- Feedback table CREATE TABLE feedbacks ( @@ -250,8 +218,8 @@ CREATE INDEX idx_file_shares_link ON file_shares(share_link); CREATE INDEX idx_share_access_logs_share_id ON share_access_logs(share_id); CREATE INDEX idx_activity_logs_user_id ON activity_logs(user_id); CREATE INDEX idx_activity_logs_created_at ON activity_logs(created_at); -CREATE INDEX idx_feedback_user_id ON feedback(user_id); -CREATE INDEX idx_feedback_status ON feedback(status); +CREATE INDEX idx_feedback_user_id ON feedbacks(user_id); +CREATE INDEX idx_feedback_status ON feedbacks(status); CREATE INDEX idx_files_is_shared ON files(is_shared); CREATE INDEX idx_files_key_version ON files(master_key_version); CREATE INDEX idx_key_fragments_key_version ON key_fragments(master_key_version); diff --git a/frontend/src/App.js b/frontend/src/App.js index 333037a..8e394c6 100644 --- a/frontend/src/App.js +++ b/frontend/src/App.js @@ -23,7 +23,6 @@ function App() { setUser(null); }; - // Helper function to check if user has access to current route const isRouteAccessible = (allowedRoles) => { return user && allowedRoles.includes(user.role); }; @@ -31,7 +30,6 @@ function App() { return (
- {/* Show NavigationBar only when user is not logged in */} {!user && } @@ -46,39 +44,24 @@ function App() { {/* Authentication Routes */} - ) : ( - - ) - } + element={user ? : } /> - ) : ( - - ) - } + element={user ? : } /> - ) : ( - - ) - } + element={user ? : } /> - } /> - } /> + {/* Share Access Routes */} + } /> + } /> + } /> + } /> {/* Protected Routes */} } /> - - {/* Catch all route */} } /> diff --git a/frontend/src/components/BillingPage.js b/frontend/src/components/BillingPage.js index d4033f4..9e720e0 100644 --- a/frontend/src/components/BillingPage.js +++ b/frontend/src/components/BillingPage.js @@ -78,7 +78,7 @@ const BillingPage = ({ user, onUpgradeSuccess }) => { try { validateCard(); const token = localStorage.getItem('token'); - + const paymentResponse = await fetch('http://localhost:8080/api/payment/upgrade', { method: 'POST', headers: { @@ -98,14 +98,26 @@ const BillingPage = ({ user, onUpgradeSuccess }) => { countryCode: billingInfo.countryCode }) }); - + if (!paymentResponse.ok) { const errorData = await paymentResponse.json(); throw new Error(errorData.error || 'Payment processing failed'); } - + + // Update local user data + const currentUser = JSON.parse(localStorage.getItem('user')); + const updatedUser = { + ...currentUser, + role: 'premium_user', + subscription_status: 'premium' + }; + localStorage.setItem('user', JSON.stringify(updatedUser)); + + // Call the upgrade success callback if (onUpgradeSuccess) onUpgradeSuccess(); - navigate('/premium-dashboard?upgraded=true'); + + // Force redirect and reload + window.location.href = '/premium-dashboard?upgraded=true'; } catch (err) { setError(err.message); } finally { diff --git a/frontend/src/components/EndUser/ArchiveFileAction.js b/frontend/src/components/EndUser/ArchiveFileAction.js index 7e25811..4cac405 100644 --- a/frontend/src/components/EndUser/ArchiveFileAction.js +++ b/frontend/src/components/EndUser/ArchiveFileAction.js @@ -1,7 +1,7 @@ import React, { useState } from 'react'; import { Archive, Loader } from 'lucide-react'; -const ArchiveFileAction = ({ file, selectedFiles = [], onRefresh }) => { +const ArchiveFileAction = ({ file, selectedFiles = [], onRefresh, onClearSelection }) => { const [isArchiving, setIsArchiving] = useState(false); const handleArchive = async () => { @@ -49,6 +49,7 @@ const ArchiveFileAction = ({ file, selectedFiles = [], onRefresh }) => { } } + onClearSelection?.(); onRefresh?.(); } catch (error) { console.error('Archive error:', error); diff --git a/frontend/src/components/EndUser/DeleteFileAction.js b/frontend/src/components/EndUser/DeleteFileAction.js index 23db347..b4c25dd 100644 --- a/frontend/src/components/EndUser/DeleteFileAction.js +++ b/frontend/src/components/EndUser/DeleteFileAction.js @@ -1,23 +1,29 @@ -import React from 'react'; -import { Trash2 } from 'lucide-react'; +import React, { useState } from 'react'; +import { Trash2, Loader } from 'lucide-react'; + +const DeleteFileAction = ({ file, selectedFiles = [], onRefresh, onClearSelection }) => { + const [isDeleting, setIsDeleting] = useState(false); -const DeleteFileAction = ({ file, selectedFiles = [] }) => { const handleDelete = async () => { const files = selectedFiles.length > 0 ? selectedFiles : [file]; const confirmMessage = `Are you sure you want to delete ${files.length > 1 ? 'these files' : 'this file'}?`; if (!window.confirm(confirmMessage)) return; + setIsDeleting(true); + try { const token = localStorage.getItem('token'); if (files.length === 1) { - await fetch(`http://localhost:8080/api/files/${files[0].id}`, { + const response = await fetch(`http://localhost:8080/api/files/${files[0].id}`, { method: 'DELETE', headers: { 'Authorization': `Bearer ${token}` }, }); + + if (!response.ok) throw new Error('Failed to delete file'); } else { - await fetch('http://localhost:8080/api/files/mass-delete', { + const response = await fetch('http://localhost:8080/api/files/mass-delete', { method: 'POST', headers: { 'Authorization': `Bearer ${token}`, @@ -27,20 +33,31 @@ const DeleteFileAction = ({ file, selectedFiles = [] }) => { file_ids: files.map(f => f.id) }) }); + + if (!response.ok) throw new Error('Failed to delete files'); } - window.location.reload(); + onClearSelection?.(); + onRefresh?.(); } catch (error) { console.error('Delete error:', error); + alert(error.message || 'Failed to delete file(s)'); + } finally { + setIsDeleting(false); } }; return ( ); diff --git a/frontend/src/components/EndUser/DownloadFileAction.js b/frontend/src/components/EndUser/DownloadFileAction.js index 1e4993b..cfe2ff4 100644 --- a/frontend/src/components/EndUser/DownloadFileAction.js +++ b/frontend/src/components/EndUser/DownloadFileAction.js @@ -1,13 +1,12 @@ import React from 'react'; import { Download } from 'lucide-react'; -const DownloadFileAction = ({ file, selectedFiles = [] }) => { +const DownloadFileAction = ({ file, selectedFiles = [], onClearSelection, onClose }) => { const handleDownload = async () => { try { const token = localStorage.getItem('token'); const filesToDownload = selectedFiles.length > 0 ? selectedFiles : [file]; - // If multiple files selected, get download status first if (selectedFiles.length > 0) { const statusResponse = await fetch('http://localhost:8080/api/files/mass-download', { method: 'POST', @@ -27,7 +26,6 @@ const DownloadFileAction = ({ file, selectedFiles = [] }) => { result => result.status === 'success' ); - // Download each available file for (const fileStatus of availableFiles) { try { const response = await fetch(`http://localhost:8080/api/files/mass-download/${fileStatus.file_id}`, { @@ -52,7 +50,6 @@ const DownloadFileAction = ({ file, selectedFiles = [] }) => { } } } else { - // Single file download const response = await fetch(`http://localhost:8080/api/files/${file.id}/download`, { headers: { 'Authorization': `Bearer ${token}`, @@ -71,8 +68,12 @@ const DownloadFileAction = ({ file, selectedFiles = [] }) => { window.URL.revokeObjectURL(url); document.body.removeChild(a); } + + onClearSelection?.(); + onClose?.(); } catch (error) { console.error('Download error:', error); + alert(error.message || 'Failed to download file(s)'); } }; diff --git a/frontend/src/components/EndUser/FileActions.js b/frontend/src/components/EndUser/FileActions.js index e73d593..ef2aac7 100644 --- a/frontend/src/components/EndUser/FileActions.js +++ b/frontend/src/components/EndUser/FileActions.js @@ -1,13 +1,39 @@ -import React, { useState } from 'react'; +import React, { useState, useRef, useEffect } from 'react'; import { Download, Trash2, Share2, Archive, MoreVertical, Check } from 'lucide-react'; import DownloadFileAction from './DownloadFileAction'; import DeleteFileAction from './DeleteFileAction'; import ShareFileAction from './ShareFileAction'; import ArchiveFileAction from './ArchiveFileAction'; -import ReportFileAction from './ReportFileAction'; +import UnarchiveFileAction from './UnarchiveFileAction'; +import ReportFileAction from './ReportFileAction'; -const FileActions = ({ file, user, onRefresh, onAction, isSelectable = false, selected = false, onSelect, selectedFiles = [] }) => { +const FileActions = ({ + file, + user, + onRefresh, + onAction, + isSelectable = false, + selected = false, + onSelect, + selectedFiles = [], + onClearSelection +}) => { const [showActions, setShowActions] = useState(false); + const actionsRef = useRef(null); + + useEffect(() => { + const handleClickOutside = (event) => { + if (actionsRef.current && !actionsRef.current.contains(event.target)) { + setShowActions(false); + } + }; + + document.addEventListener('mousedown', handleClickOutside); + + return () => { + document.removeEventListener('mousedown', handleClickOutside); + }; + }, []); const handleClick = (e) => { if (isSelectable) { @@ -18,8 +44,16 @@ const FileActions = ({ file, user, onRefresh, onAction, isSelectable = false, se } }; + const handleClose = () => { + setShowActions(false); + }; + + const allFilesArchived = selectedFiles.length > 0 + ? selectedFiles.every(f => f.is_archived) + : file.is_archived; + return ( -
+
)}
diff --git a/frontend/src/components/EndUser/Settings.js b/frontend/src/components/EndUser/Settings.js index 09cb937..dc34c79 100644 --- a/frontend/src/components/EndUser/Settings.js +++ b/frontend/src/components/EndUser/Settings.js @@ -7,16 +7,40 @@ import TwoFactorSettings from './TwoFactorAuthentication'; const Settings = ({ user: initialUser, onUserUpdate }) => { const [activeTab, setActiveTab] = useState('account'); const [currentUser, setCurrentUser] = useState(initialUser?.data?.user || {}); - const [billingProfile, setBillingProfile] = useState(initialUser?.data?.billing_profile || {}); + const [billingProfile, setBillingProfile] = useState(null); const [loading, setLoading] = useState(false); const [error, setError] = useState(''); const [showCancelModal, setShowCancelModal] = useState(false); const [cancelInfo, setCancelInfo] = useState(null); + const [isEditing, setIsEditing] = useState(false); + const [formData, setFormData] = useState({ + billing_name: '', + billing_email: '', + billing_address: '', + country_code: '', + default_payment_method: 'credit_card', + billing_cycle: 'monthly', + currency: 'USD' + }); useEffect(() => { fetchUserData(); }, []); + useEffect(() => { + if (billingProfile) { + setFormData({ + billing_name: billingProfile.billing_name || '', + billing_email: billingProfile.billing_email || '', + billing_address: billingProfile.billing_address || '', + country_code: billingProfile.country_code || '', + default_payment_method: billingProfile.default_payment_method || 'credit_card', + billing_cycle: billingProfile.billing_cycle || 'monthly', + currency: billingProfile.currency || 'USD' + }); + } + }, [billingProfile]); + const fetchUserData = async () => { try { const token = localStorage.getItem('token'); @@ -29,7 +53,9 @@ const Settings = ({ user: initialUser, onUserUpdate }) => { if (response.ok) { const data = await response.json(); setCurrentUser(data.data.user); - setBillingProfile(data.data.billing_profile); + if (data.data.billing_profile) { + setBillingProfile(data.data.billing_profile); + } if (onUserUpdate) onUserUpdate(data); } } catch (error) { @@ -66,6 +92,38 @@ const Settings = ({ user: initialUser, onUserUpdate }) => { } }; + const handleUpdateBilling = async (e) => { + e.preventDefault(); + setLoading(true); + setError(''); + + try { + const token = localStorage.getItem('token'); + const response = await fetch('http://localhost:8080/api/premium/billing/details', { + method: 'PUT', + headers: { + 'Authorization': `Bearer ${token}`, + 'Content-Type': 'application/json' + }, + body: JSON.stringify(formData) + }); + + const data = await response.json(); + + if (!response.ok) { + throw new Error(data.error || 'Failed to update billing details'); + } + + setBillingProfile(data.data); + setIsEditing(false); + await fetchUserData(); + } catch (err) { + setError(err.message); + } finally { + setLoading(false); + } + }; + const CancelModal = () => (
@@ -89,11 +147,173 @@ const Settings = ({ user: initialUser, onUserUpdate }) => {
); - const formatBytes = (bytes) => { - const sizes = ['Bytes', 'KB', 'MB', 'GB', 'TB']; - if (bytes === 0) return '0 Bytes'; - const i = parseInt(Math.floor(Math.log(bytes) / Math.log(1024))); - return Math.round(bytes / Math.pow(1024, i), 2) + ' ' + sizes[i]; + const renderAccountDetails = () => { + return ( +
+

Username: {currentUser.username}

+

Email: {currentUser.email}

+

Subscription Plan: {currentUser.subscription_status}

+ + {billingProfile && ( + <> +

Billing Cycle: {billingProfile.billing_cycle}

+

+ Next Billing Date: {' '} + {billingProfile.next_billing_date + ? new Date(billingProfile.next_billing_date).toLocaleDateString() + : 'N/A' + } +

+ + )} + + {currentUser.subscription_status === 'premium' && + billingProfile?.billing_status === 'active' && ( +
+ +
+ )} + + {billingProfile?.billing_status === 'cancelled' && ( +
+

+ Your subscription is cancelled and will end on{' '} + {new Date(billingProfile.next_billing_date).toLocaleDateString()} +

+
+ )} +
+ ); + }; + + const renderBillingDetails = () => { + if (!isEditing) { + return ( +
+
+

Billing Details

+ +
+ + {billingProfile ? ( +
+

Billing Name: {billingProfile.billing_name}

+

Billing Email: {billingProfile.billing_email}

+

Billing Address: {billingProfile.billing_address}

+

Country: {billingProfile.country_code}

+

Payment Method: {billingProfile.default_payment_method}

+

Billing Cycle: {billingProfile.billing_cycle}

+

Currency: {billingProfile.currency}

+
+ ) : ( +

No billing profile found. Click Edit to set up billing details.

+ )} +
+ ); + } + + return ( +
+

Edit Billing Details

+ +
+ + setFormData(prev => ({ ...prev, billing_name: e.target.value }))} + className="mt-1 block w-full rounded-md border-gray-300 shadow-sm focus:border-blue-500 focus:ring-blue-500" + required + /> +
+ +
+ + setFormData(prev => ({ ...prev, billing_email: e.target.value }))} + className="mt-1 block w-full rounded-md border-gray-300 shadow-sm focus:border-blue-500 focus:ring-blue-500" + required + /> +
+ +
+ +